orca-sdk 0.0.96__py3-none-any.whl → 0.0.98__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269) hide show
  1. orca_sdk/__init__.py +2 -5
  2. orca_sdk/_shared/__init__.py +1 -0
  3. orca_sdk/_shared/metrics.py +1 -1
  4. orca_sdk/_utils/analysis_ui.py +5 -5
  5. orca_sdk/_utils/auth.py +23 -33
  6. orca_sdk/_utils/pagination.py +126 -0
  7. orca_sdk/_utils/pagination_test.py +132 -0
  8. orca_sdk/classification_model.py +188 -126
  9. orca_sdk/classification_model_test.py +102 -0
  10. orca_sdk/client.py +3515 -0
  11. orca_sdk/conftest.py +10 -0
  12. orca_sdk/credentials.py +73 -21
  13. orca_sdk/credentials_test.py +20 -0
  14. orca_sdk/datasource.py +186 -81
  15. orca_sdk/datasource_test.py +194 -0
  16. orca_sdk/embedding_model.py +267 -75
  17. orca_sdk/embedding_model_test.py +32 -14
  18. orca_sdk/job.py +59 -54
  19. orca_sdk/job_test.py +50 -0
  20. orca_sdk/memoryset.py +372 -345
  21. orca_sdk/memoryset_test.py +7 -11
  22. orca_sdk/regression_model.py +120 -111
  23. orca_sdk/regression_model_test.py +15 -0
  24. orca_sdk/telemetry.py +229 -115
  25. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.98.dist-info}/METADATA +19 -5
  26. orca_sdk-0.0.98.dist-info/RECORD +40 -0
  27. orca_sdk/_generated_api_client/__init__.py +0 -3
  28. orca_sdk/_generated_api_client/api/__init__.py +0 -287
  29. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  31. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  32. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  33. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  34. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  35. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  36. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +0 -170
  37. orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
  38. orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  39. orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
  40. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
  41. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  42. orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  43. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
  44. orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  45. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
  46. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  48. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
  49. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  50. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
  51. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
  53. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
  55. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  57. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  58. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  59. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  60. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  61. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  62. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  64. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
  66. orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
  67. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  68. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  69. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  70. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  71. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  72. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  73. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
  74. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
  75. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
  77. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
  78. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
  79. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
  80. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
  81. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
  82. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
  83. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
  84. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
  85. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
  86. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  87. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/regression_model/create_regression_model_gpu_regression_model_post.py +0 -170
  94. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  95. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
  96. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
  97. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  98. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
  100. orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
  102. orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
  103. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  104. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  105. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  106. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
  107. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -293
  108. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  109. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
  110. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  111. orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
  112. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  113. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
  114. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  115. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
  116. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
  117. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  118. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  119. orca_sdk/_generated_api_client/client.py +0 -216
  120. orca_sdk/_generated_api_client/errors.py +0 -38
  121. orca_sdk/_generated_api_client/models/__init__.py +0 -295
  122. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
  123. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
  124. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
  125. orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
  126. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  127. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
  128. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
  129. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
  130. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
  131. orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
  132. orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
  133. orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -213
  134. orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
  135. orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -170
  136. orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
  137. orca_sdk/_generated_api_client/models/column_info.py +0 -145
  138. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  139. orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -80
  140. orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
  141. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
  142. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
  143. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
  144. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
  145. orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -197
  146. orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -325
  147. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
  148. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
  149. orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -137
  150. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
  151. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  152. orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
  153. orca_sdk/_generated_api_client/models/embed_request.py +0 -135
  154. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -187
  155. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -179
  156. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
  157. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
  158. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  159. orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -114
  160. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
  161. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
  162. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
  163. orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
  164. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  165. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  166. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
  167. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -20
  168. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  169. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  170. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  171. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  172. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  173. orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
  174. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  176. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
  177. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  178. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
  179. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
  180. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
  181. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  182. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
  183. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  184. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  185. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
  186. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  187. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  188. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
  189. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
  190. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
  191. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  192. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
  193. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  194. orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
  195. orca_sdk/_generated_api_client/models/memory_metrics.py +0 -165
  196. orca_sdk/_generated_api_client/models/memory_type.py +0 -9
  197. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -212
  198. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
  199. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
  200. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
  201. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
  202. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
  203. orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
  204. orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
  205. orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
  206. orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
  207. orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
  208. orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -291
  209. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
  210. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
  211. orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -232
  212. orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
  213. orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
  214. orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
  215. orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
  216. orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
  217. orca_sdk/_generated_api_client/models/memoryset_update.py +0 -101
  218. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  219. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -22
  220. orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
  221. orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
  222. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  223. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  224. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  225. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  226. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
  227. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
  228. orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -91
  229. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -107
  230. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
  231. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  232. orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
  233. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
  234. orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
  235. orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -177
  236. orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
  237. orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
  238. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
  239. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
  240. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
  241. orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
  242. orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
  243. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
  244. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
  245. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
  246. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
  247. orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
  248. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
  249. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
  250. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
  251. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
  252. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  253. orca_sdk/_generated_api_client/models/task.py +0 -198
  254. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  255. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  256. orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
  257. orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -205
  258. orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
  259. orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
  260. orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -173
  261. orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
  262. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  263. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  264. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
  265. orca_sdk/_generated_api_client/models/validation_error.py +0 -99
  266. orca_sdk/_generated_api_client/py.typed +0 -1
  267. orca_sdk/_generated_api_client/types.py +0 -56
  268. orca_sdk-0.0.96.dist-info/RECORD +0 -278
  269. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.98.dist-info}/WHEEL +0 -0
@@ -1,56 +1,61 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
4
  from contextlib import contextmanager
6
5
  from datetime import datetime
7
6
  from typing import Any, Generator, Iterable, Literal, cast, overload
8
- from uuid import UUID
9
7
 
10
8
  from datasets import Dataset
11
9
 
12
- from ._generated_api_client.api import (
13
- create_classification_model_gpu,
14
- delete_classification_model,
15
- evaluate_classification_model,
16
- get_classification_model,
17
- get_classification_model_evaluation,
18
- list_classification_models,
19
- list_predictions,
20
- predict_label_gpu,
21
- record_prediction_feedback,
22
- update_classification_model,
23
- )
24
- from ._generated_api_client.models import (
25
- ClassificationEvaluationRequest,
10
+ from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
11
+ from ._utils.common import UNSET, CreateMode, DropMode
12
+ from .client import (
13
+ BootstrapClassificationModelMeta,
14
+ BootstrapClassificationModelResult,
26
15
  ClassificationModelMetadata,
27
- ClassificationPredictionRequest,
28
- CreateClassificationModelRequest,
29
- LabelPredictionWithMemoriesAndFeedback,
30
- ListPredictionsRequest,
31
- )
32
- from ._generated_api_client.models import (
33
- PredictionSortItemItemType0 as PredictionSortColumns,
34
- )
35
- from ._generated_api_client.models import (
36
- PredictionSortItemItemType1 as PredictionSortDirection,
37
- )
38
- from ._generated_api_client.models import (
39
16
  PredictiveModelUpdate,
40
17
  RACHeadType,
18
+ orca_api,
41
19
  )
42
- from ._generated_api_client.types import UNSET as CLIENT_UNSET
43
- from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
44
- from ._utils.common import UNSET, CreateMode, DropMode
45
20
  from .datasource import Datasource
46
21
  from .job import Job
47
22
  from .memoryset import (
48
23
  FilterItem,
49
24
  FilterItemTuple,
50
25
  LabeledMemoryset,
26
+ _is_metric_column,
51
27
  _parse_filter_item_from_tuple,
52
28
  )
53
- from .telemetry import ClassificationPrediction, _parse_feedback
29
+ from .telemetry import (
30
+ ClassificationPrediction,
31
+ TelemetryMode,
32
+ _get_telemetry_config,
33
+ _parse_feedback,
34
+ )
35
+
36
+
37
+ class BootstrappedClassificationModel:
38
+
39
+ datasource: Datasource | None
40
+ memoryset: LabeledMemoryset | None
41
+ classification_model: ClassificationModel | None
42
+ agent_output: BootstrapClassificationModelResult | None
43
+
44
+ def __init__(self, metadata: BootstrapClassificationModelMeta):
45
+ self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
46
+ self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
47
+ self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
48
+ self.agent_output = metadata["agent_output"]
49
+
50
+ def __repr__(self):
51
+ return (
52
+ "BootstrappedClassificationModel({\n"
53
+ f" datasource: {self.datasource},\n"
54
+ f" memoryset: {self.memoryset},\n"
55
+ f" classification_model: {self.classification_model},\n"
56
+ f" agent_output: {self.agent_output},\n"
57
+ "})"
58
+ )
54
59
 
55
60
 
56
61
  class ClassificationModel:
@@ -86,18 +91,18 @@ class ClassificationModel:
86
91
 
87
92
  def __init__(self, metadata: ClassificationModelMetadata):
88
93
  # for internal use only, do not document
89
- self.id = metadata.id
90
- self.name = metadata.name
91
- self.description = metadata.description
92
- self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
93
- self.head_type = metadata.head_type
94
- self.num_classes = metadata.num_classes
95
- self.memory_lookup_count = metadata.memory_lookup_count
96
- self.weigh_memories = metadata.weigh_memories
97
- self.min_memory_weight = metadata.min_memory_weight
98
- self.version = metadata.version
99
- self.locked = metadata.locked
100
- self.created_at = metadata.created_at
94
+ self.id = metadata["id"]
95
+ self.name = metadata["name"]
96
+ self.description = metadata["description"]
97
+ self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
98
+ self.head_type = metadata["head_type"]
99
+ self.num_classes = metadata["num_classes"]
100
+ self.memory_lookup_count = metadata["memory_lookup_count"]
101
+ self.weigh_memories = metadata["weigh_memories"]
102
+ self.min_memory_weight = metadata["min_memory_weight"]
103
+ self.version = metadata["version"]
104
+ self.locked = metadata["locked"]
105
+ self.created_at = datetime.fromisoformat(metadata["created_at"])
101
106
 
102
107
  self._memoryset_override_id: str | None = None
103
108
  self._last_prediction: ClassificationPrediction | None = None
@@ -140,7 +145,7 @@ class ClassificationModel:
140
145
  cls,
141
146
  name: str,
142
147
  memoryset: LabeledMemoryset,
143
- head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
148
+ head_type: RACHeadType = "KNN",
144
149
  *,
145
150
  description: str | None = None,
146
151
  num_classes: int | None = None,
@@ -206,17 +211,18 @@ class ClassificationModel:
206
211
 
207
212
  return existing
208
213
 
209
- metadata = create_classification_model_gpu(
210
- body=CreateClassificationModelRequest(
211
- name=name,
212
- memoryset_id=memoryset.id,
213
- head_type=RACHeadType(head_type),
214
- memory_lookup_count=memory_lookup_count,
215
- num_classes=num_classes,
216
- weigh_memories=weigh_memories,
217
- min_memory_weight=min_memory_weight,
218
- description=description,
219
- ),
214
+ metadata = orca_api.POST(
215
+ "/classification_model",
216
+ json={
217
+ "name": name,
218
+ "memoryset_name_or_id": memoryset.id,
219
+ "head_type": head_type,
220
+ "memory_lookup_count": memory_lookup_count,
221
+ "num_classes": num_classes,
222
+ "weigh_memories": weigh_memories,
223
+ "min_memory_weight": min_memory_weight,
224
+ "description": description,
225
+ },
220
226
  )
221
227
  return cls(metadata)
222
228
 
@@ -234,7 +240,7 @@ class ClassificationModel:
234
240
  Raises:
235
241
  LookupError: If the classification model does not exist
236
242
  """
237
- return cls(get_classification_model(name))
243
+ return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
238
244
 
239
245
  @classmethod
240
246
  def exists(cls, name_or_id: str) -> bool:
@@ -261,7 +267,7 @@ class ClassificationModel:
261
267
  Returns:
262
268
  List of handles to all classification models in the OrcaCloud
263
269
  """
264
- return [cls(metadata) for metadata in list_classification_models()]
270
+ return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
265
271
 
266
272
  @classmethod
267
273
  def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -280,7 +286,7 @@ class ClassificationModel:
280
286
  LookupError: If the classification model does not exist and if_not_exists is `"error"`
281
287
  """
282
288
  try:
283
- delete_classification_model(name_or_id)
289
+ orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
284
290
  logging.info(f"Deleted model {name_or_id}")
285
291
  except LookupError:
286
292
  if if_not_exists == "error":
@@ -311,11 +317,12 @@ class ClassificationModel:
311
317
  Lock the model:
312
318
  >>> model.set(locked=True)
313
319
  """
314
- update_data = PredictiveModelUpdate(
315
- description=CLIENT_UNSET if description is UNSET else description,
316
- locked=CLIENT_UNSET if locked is UNSET else locked,
317
- )
318
- update_classification_model(self.id, body=update_data)
320
+ update: PredictiveModelUpdate = {}
321
+ if description is not UNSET:
322
+ update["description"] = description
323
+ if locked is not UNSET:
324
+ update["locked"] = locked
325
+ orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
319
326
  self.refresh()
320
327
 
321
328
  def lock(self) -> None:
@@ -333,7 +340,9 @@ class ClassificationModel:
333
340
  expected_labels: list[int] | None = None,
334
341
  filters: list[FilterItemTuple] = [],
335
342
  tags: set[str] | None = None,
336
- save_telemetry: Literal["off", "on", "sync", "async"] = "on",
343
+ save_telemetry: TelemetryMode = "on",
344
+ prompt: str | None = None,
345
+ use_lookup_cache: bool = True,
337
346
  ) -> list[ClassificationPrediction]:
338
347
  pass
339
348
 
@@ -344,17 +353,21 @@ class ClassificationModel:
344
353
  expected_labels: int | None = None,
345
354
  filters: list[FilterItemTuple] = [],
346
355
  tags: set[str] | None = None,
347
- save_telemetry: Literal["off", "on", "sync", "async"] = "on",
356
+ save_telemetry: TelemetryMode = "on",
357
+ prompt: str | None = None,
358
+ use_lookup_cache: bool = True,
348
359
  ) -> ClassificationPrediction:
349
360
  pass
350
361
 
351
362
  def predict(
352
363
  self,
353
364
  value: list[str] | str,
354
- expected_labels: list[int] | int | None = None,
365
+ expected_labels: list[int] | list[str] | int | str | None = None,
355
366
  filters: list[FilterItemTuple] = [],
356
367
  tags: set[str] | None = None,
357
- save_telemetry: Literal["off", "on", "sync", "async"] = "on",
368
+ save_telemetry: TelemetryMode = "on",
369
+ prompt: str | None = None,
370
+ use_lookup_cache: bool = True,
358
371
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
359
372
  """
360
373
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -362,6 +375,7 @@ class ClassificationModel:
362
375
  Params:
363
376
  value: Value(s) to get predict the labels of
364
377
  expected_labels: Expected label(s) for the given input to record for model evaluation
378
+ filters: Optional filters to apply during memory lookup
365
379
  tags: Tags to add to the prediction(s)
366
380
  save_telemetry: Whether to save telemetry for the prediction(s). One of
367
381
  * `"off"`: Do not save telemetry
@@ -369,6 +383,7 @@ class ClassificationModel:
369
383
  environment variable is set.
370
384
  * `"sync"`: Save telemetry synchronously
371
385
  * `"async"`: Save telemetry asynchronously
386
+ prompt: Optional prompt to use for instruction-tuned embedding models
372
387
 
373
388
  Returns:
374
389
  Label prediction or list of label predictions
@@ -384,48 +399,60 @@ class ClassificationModel:
384
399
  ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
385
400
  ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
386
401
  ]
402
+
403
+ Using a prompt with an instruction-tuned embedding model:
404
+ >>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
405
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
387
406
  """
388
407
 
389
408
  parsed_filters = [
390
409
  _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
391
410
  ]
392
411
 
393
- if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
412
+ if any(_is_metric_column(filter[0]) for filter in filters):
394
413
  raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
395
414
 
396
- response = predict_label_gpu(
397
- self.id,
398
- body=ClassificationPredictionRequest(
399
- input_values=value if isinstance(value, list) else [value],
400
- memoryset_override_id=self._memoryset_override_id,
401
- expected_labels=(
402
- expected_labels
403
- if isinstance(expected_labels, list)
404
- else [expected_labels] if expected_labels is not None else None
405
- ),
406
- tags=list(tags or set()),
407
- save_telemetry=save_telemetry != "off",
408
- save_telemetry_synchronously=(
409
- os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
410
- ),
411
- filters=cast(list[FilterItem], parsed_filters),
412
- ),
415
+ if isinstance(expected_labels, int):
416
+ expected_labels = [expected_labels]
417
+ elif isinstance(expected_labels, str):
418
+ expected_labels = [self.memoryset.label_names.index(expected_labels)]
419
+ elif isinstance(expected_labels, list):
420
+ expected_labels = [
421
+ self.memoryset.label_names.index(label) if isinstance(label, str) else label
422
+ for label in expected_labels
423
+ ]
424
+
425
+ telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
426
+ response = orca_api.POST(
427
+ "/gpu/classification_model/{name_or_id}/prediction",
428
+ params={"name_or_id": self.id},
429
+ json={
430
+ "input_values": value if isinstance(value, list) else [value],
431
+ "memoryset_override_name_or_id": self._memoryset_override_id,
432
+ "expected_labels": expected_labels,
433
+ "tags": list(tags or set()),
434
+ "save_telemetry": telemetry_on,
435
+ "save_telemetry_synchronously": telemetry_sync,
436
+ "filters": cast(list[FilterItem], parsed_filters),
437
+ "prompt": prompt,
438
+ "use_lookup_cache": use_lookup_cache,
439
+ },
413
440
  )
414
441
 
415
- if save_telemetry != "off" and any(p.prediction_id is None for p in response):
442
+ if telemetry_on and any(p["prediction_id"] is None for p in response):
416
443
  raise RuntimeError("Failed to save prediction to database.")
417
444
 
418
445
  predictions = [
419
446
  ClassificationPrediction(
420
- prediction_id=prediction.prediction_id,
421
- label=prediction.label,
422
- label_name=prediction.label_name,
447
+ prediction_id=prediction["prediction_id"],
448
+ label=prediction["label"],
449
+ label_name=prediction["label_name"],
423
450
  score=None,
424
- confidence=prediction.confidence,
425
- anomaly_score=prediction.anomaly_score,
451
+ confidence=prediction["confidence"],
452
+ anomaly_score=prediction["anomaly_score"],
426
453
  memoryset=self.memoryset,
427
454
  model=self,
428
- logits=prediction.logits,
455
+ logits=prediction["logits"],
429
456
  input_value=input_value,
430
457
  )
431
458
  for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
@@ -439,7 +466,7 @@ class ClassificationModel:
439
466
  limit: int = 100,
440
467
  offset: int = 0,
441
468
  tag: str | None = None,
442
- sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
469
+ sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
443
470
  expected_label_match: bool | None = None,
444
471
  ) -> list[ClassificationPrediction]:
445
472
  """
@@ -475,30 +502,31 @@ class ClassificationModel:
475
502
  >>> predictions = model.predictions(expected_label_match=False)
476
503
  [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
477
504
  """
478
- predictions = list_predictions(
479
- body=ListPredictionsRequest(
480
- model_id=self.id,
481
- limit=limit,
482
- offset=offset,
483
- sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
484
- tag=tag,
485
- expected_label_match=expected_label_match,
486
- ),
505
+ predictions = orca_api.POST(
506
+ "/telemetry/prediction",
507
+ json={
508
+ "model_id": self.id,
509
+ "limit": limit,
510
+ "offset": offset,
511
+ "sort": [list(sort_item) for sort_item in sort],
512
+ "tag": tag,
513
+ "expected_label_match": expected_label_match,
514
+ },
487
515
  )
488
516
  return [
489
517
  ClassificationPrediction(
490
- prediction_id=prediction.prediction_id,
491
- label=prediction.label,
492
- label_name=prediction.label_name,
518
+ prediction_id=prediction["prediction_id"],
519
+ label=prediction["label"],
520
+ label_name=prediction["label_name"],
493
521
  score=None,
494
- confidence=prediction.confidence,
495
- anomaly_score=prediction.anomaly_score,
522
+ confidence=prediction["confidence"],
523
+ anomaly_score=prediction["anomaly_score"],
496
524
  memoryset=self.memoryset,
497
525
  model=self,
498
526
  telemetry=prediction,
499
527
  )
500
528
  for prediction in predictions
501
- if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
529
+ if "label" in prediction
502
530
  ]
503
531
 
504
532
  def _evaluate_datasource(
@@ -510,23 +538,28 @@ class ClassificationModel:
510
538
  tags: set[str] | None,
511
539
  background: bool = False,
512
540
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
513
- response = evaluate_classification_model(
514
- self.id,
515
- body=ClassificationEvaluationRequest(
516
- datasource_id=datasource.id,
517
- datasource_label_column=label_column,
518
- datasource_value_column=value_column,
519
- memoryset_override_id=self._memoryset_override_id,
520
- record_telemetry=record_predictions,
521
- telemetry_tags=list(tags) if tags else None,
522
- ),
541
+ response = orca_api.POST(
542
+ "/classification_model/{model_name_or_id}/evaluation",
543
+ params={"model_name_or_id": self.id},
544
+ json={
545
+ "datasource_name_or_id": datasource.id,
546
+ "datasource_label_column": label_column,
547
+ "datasource_value_column": value_column,
548
+ "memoryset_override_name_or_id": self._memoryset_override_id,
549
+ "record_telemetry": record_predictions,
550
+ "telemetry_tags": list(tags) if tags else None,
551
+ },
523
552
  )
524
553
 
525
- job = Job(
526
- response.task_id,
527
- lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
528
- and ClassificationMetrics(**r.to_dict()),
529
- )
554
+ def get_value():
555
+ res = orca_api.GET(
556
+ "/classification_model/{model_name_or_id}/evaluation/{task_id}",
557
+ params={"model_name_or_id": self.id, "task_id": response["task_id"]},
558
+ )
559
+ assert res["result"] is not None
560
+ return ClassificationMetrics(**res["result"])
561
+
562
+ job = Job(response["task_id"], get_value)
530
563
  return job if background else job.result()
531
564
 
532
565
  def _evaluate_dataset(
@@ -709,8 +742,37 @@ class ClassificationModel:
709
742
  ValueError: If the value does not match previous value types for the category, or is a
710
743
  [`float`][float] that is not between `-1.0` and `+1.0`.
711
744
  """
712
- record_prediction_feedback(
713
- body=[
745
+ orca_api.PUT(
746
+ "/telemetry/prediction/feedback",
747
+ json=[
714
748
  _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
715
749
  ],
716
750
  )
751
+
752
+ @staticmethod
753
+ def bootstrap_model(
754
+ model_description: str,
755
+ label_names: list[str],
756
+ initial_examples: list[tuple[str, str]],
757
+ num_examples_per_label: int,
758
+ background: bool = False,
759
+ ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
760
+ response = orca_api.POST(
761
+ "/agents/bootstrap_classification_model",
762
+ json={
763
+ "model_description": model_description,
764
+ "label_names": label_names,
765
+ "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
766
+ "num_examples_per_label": num_examples_per_label,
767
+ },
768
+ )
769
+
770
+ def get_result() -> BootstrappedClassificationModel:
771
+ res = orca_api.GET(
772
+ "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
773
+ )
774
+ assert res["result"] is not None
775
+ return BootstrappedClassificationModel(res["result"])
776
+
777
+ job = Job(response["task_id"], get_result)
778
+ return job if background else job.result()
@@ -326,6 +326,37 @@ def test_predict_with_filters(classification_model: ClassificationModel):
326
326
  assert filtered_prediction.label_name == "cats"
327
327
 
328
328
 
329
+ def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
330
+ model = ClassificationModel.create(
331
+ "test_predict_with_memoryset_update",
332
+ writable_memoryset,
333
+ num_classes=2,
334
+ memory_lookup_count=3,
335
+ )
336
+
337
+ prediction = model.predict("Do you love soup?")
338
+ assert prediction.label == 0
339
+ assert prediction.label_name == "soup"
340
+
341
+ # insert new memories
342
+ writable_memoryset.insert(
343
+ [
344
+ {"value": "Do you love soup?", "label": 1, "key": "g1"},
345
+ {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
346
+ {"value": "Do you love crackers?", "label": 1, "key": "g2"},
347
+ {"value": "Do you love broth?", "label": 1, "key": "g2"},
348
+ {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
349
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
350
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
351
+ ],
352
+ )
353
+ prediction = model.predict("Do you love soup?")
354
+ assert prediction.label == 1
355
+ assert prediction.label_name == "cats"
356
+
357
+ ClassificationModel.drop("test_predict_with_memoryset_update")
358
+
359
+
329
360
  def test_last_prediction_with_batch(classification_model: ClassificationModel):
330
361
  predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
331
362
  assert classification_model.last_prediction is not None
@@ -374,3 +405,74 @@ def test_explain(writable_memoryset: LabeledMemoryset):
374
405
  raise e
375
406
  finally:
376
407
  ClassificationModel.drop("test_model_for_explain")
408
+
409
+
410
+ @skip_in_ci("We don't have Anthropic API key in CI")
411
+ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
412
+ """Test getting action recommendations for predictions"""
413
+
414
+ writable_memoryset.analyze(
415
+ {"name": "neighbor", "neighbor_counts": [1, 3]},
416
+ lookup_count=3,
417
+ )
418
+
419
+ model = ClassificationModel.create(
420
+ "test_model_for_action",
421
+ writable_memoryset,
422
+ num_classes=2,
423
+ memory_lookup_count=3,
424
+ description="This is a test model for action recommendations",
425
+ )
426
+
427
+ # Make a prediction with expected label to simulate incorrect prediction
428
+ prediction = model.predict("Do you love soup?", expected_labels=1)
429
+
430
+ memoryset_length = model.memoryset.length
431
+
432
+ try:
433
+ # Get action recommendation
434
+ action, rationale = prediction.recommend_action()
435
+
436
+ assert action is not None
437
+ assert rationale is not None
438
+ assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
439
+ assert len(rationale) > 10
440
+
441
+ # Test memory suggestions
442
+ suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
443
+ memory_suggestions = suggestions_response.suggestions
444
+
445
+ assert memory_suggestions is not None
446
+ assert len(memory_suggestions) == 2
447
+
448
+ for suggestion in memory_suggestions:
449
+ assert isinstance(suggestion[0], str)
450
+ assert len(suggestion[0]) > 0
451
+ assert isinstance(suggestion[1], str)
452
+ assert suggestion[1] in model.memoryset.label_names
453
+
454
+ suggestions_response.apply()
455
+
456
+ model.memoryset.refresh()
457
+ assert model.memoryset.length == memoryset_length + 2
458
+
459
+ except Exception as e:
460
+ if "ANTHROPIC_API_KEY" in str(e):
461
+ logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
462
+ else:
463
+ raise e
464
+ finally:
465
+ ClassificationModel.drop("test_model_for_action")
466
+
467
+
468
+ def test_predict_with_prompt(classification_model: ClassificationModel):
469
+ """Test that prompt parameter is properly passed through to predictions"""
470
+ # Test with an instruction-supporting embedding model if available
471
+ prediction_with_prompt = classification_model.predict(
472
+ "I love this product!", prompt="Represent this text for sentiment classification:"
473
+ )
474
+ prediction_without_prompt = classification_model.predict("I love this product!")
475
+
476
+ # Both should work and return valid predictions
477
+ assert prediction_with_prompt.label is not None
478
+ assert prediction_without_prompt.label is not None