orca-sdk 0.0.92__py3-none-any.whl → 0.0.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. orca_sdk/_generated_api_client/api/__init__.py +8 -0
  2. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
  3. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
  4. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  5. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  6. orca_sdk/_generated_api_client/models/__init__.py +10 -0
  7. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
  8. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
  9. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
  10. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  11. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
  12. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  13. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
  14. orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
  15. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
  16. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
  17. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  18. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  19. orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
  20. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  21. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  22. orca_sdk/_utils/data_parsing.py +31 -2
  23. orca_sdk/_utils/data_parsing_test.py +18 -15
  24. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  25. orca_sdk/classification_model.py +32 -12
  26. orca_sdk/classification_model_test.py +95 -34
  27. orca_sdk/conftest.py +87 -25
  28. orca_sdk/datasource.py +56 -12
  29. orca_sdk/datasource_test.py +9 -0
  30. orca_sdk/embedding_model_test.py +6 -5
  31. orca_sdk/memoryset.py +78 -0
  32. orca_sdk/memoryset_test.py +199 -123
  33. orca_sdk/telemetry.py +5 -3
  34. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
  35. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +36 -28
  36. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0
orca_sdk/_generated_api_client/models/http_validation_error.py
@@ -0,0 +1,86 @@
+"""
+This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+It is a customized template from the openapi-python-client tool's default template:
+https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+The main change is:
+- Fix typing issues
+"""
+
+# flake8: noqa: C901
+
+from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
+
+from attrs import define as _attrs_define
+from attrs import field as _attrs_field
+
+from ..types import UNSET, Unset
+
+if TYPE_CHECKING:
+    from ..models.validation_error import ValidationError
+
+
+T = TypeVar("T", bound="HTTPValidationError")
+
+
+@_attrs_define
+class HTTPValidationError:
+    """
+    Attributes:
+        detail (Union[Unset, List['ValidationError']]):
+    """
+
+    detail: Union[Unset, List["ValidationError"]] = UNSET
+    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        detail: Union[Unset, List[Dict[str, Any]]] = UNSET
+        if not isinstance(self.detail, Unset):
+            detail = []
+            for detail_item_data in self.detail:
+                detail_item = detail_item_data.to_dict()
+                detail.append(detail_item)
+
+        field_dict: dict[str, Any] = {}
+        field_dict.update(self.additional_properties)
+        field_dict.update({})
+        if detail is not UNSET:
+            field_dict["detail"] = detail
+
+        return field_dict
+
+    @classmethod
+    def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+        from ..models.validation_error import ValidationError
+
+        d = src_dict.copy()
+        detail = []
+        _detail = d.pop("detail", UNSET)
+        for detail_item_data in _detail or []:
+            detail_item = ValidationError.from_dict(detail_item_data)
+
+            detail.append(detail_item)
+
+        http_validation_error = cls(
+            detail=detail,
+        )
+
+        http_validation_error.additional_properties = d
+        return http_validation_error
+
+    @property
+    def additional_keys(self) -> list[str]:
+        return list(self.additional_properties.keys())
+
+    def __getitem__(self, key: str) -> Any:
+        return self.additional_properties[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self.additional_properties[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.additional_properties[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.additional_properties
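
For orientation, the new HTTPValidationError model mirrors FastAPI's standard 422 validation response. A minimal sketch of the round trip, assuming a typical FastAPI-style payload (the payload itself is hypothetical, not taken from this SDK):

    from orca_sdk._generated_api_client.models.http_validation_error import HTTPValidationError

    # A typical FastAPI-style 422 body (hypothetical example payload)
    payload = {
        "detail": [
            {
                "loc": ["body", "input_values", 0],
                "msg": "field required",
                "type": "value_error.missing",
            }
        ]
    }

    err = HTTPValidationError.from_dict(payload)
    for item in err.detail or []:
        print(item.loc, "->", item.msg)  # ['body', 'input_values', 0] -> field required

    assert err.to_dict() == payload  # the round trip is symmetric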
orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py
@@ -38,6 +38,7 @@ class LabelPredictionMemoryLookup:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (MemoryMetrics):
         label (int):
         label_name (Union[None, str]):
@@ -54,6 +55,7 @@ class LabelPredictionMemoryLookup:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "MemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -81,6 +83,8 @@ class LabelPredictionMemoryLookup:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -106,6 +110,7 @@ class LabelPredictionMemoryLookup:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -148,6 +153,8 @@ class LabelPredictionMemoryLookup:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = MemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -174,6 +181,7 @@ class LabelPredictionMemoryLookup:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory.py
@@ -38,6 +38,7 @@ class LabeledMemory:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (LabeledMemoryMetrics): Metrics computed for a labeled memory.
         label (int):
         label_name (Union[None, str]):
@@ -51,6 +52,7 @@ class LabeledMemory:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "LabeledMemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -75,6 +77,8 @@ class LabeledMemory:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -94,6 +98,7 @@ class LabeledMemory:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -133,6 +138,8 @@ class LabeledMemory:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = LabeledMemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -153,6 +160,7 @@ class LabeledMemory:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory_lookup.py
@@ -38,6 +38,7 @@ class LabeledMemoryLookup:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (MemoryMetrics):
         label (int):
         label_name (Union[None, str]):
@@ -52,6 +53,7 @@ class LabeledMemoryLookup:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "MemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -77,6 +79,8 @@ class LabeledMemoryLookup:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -98,6 +102,7 @@ class LabeledMemoryLookup:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -138,6 +143,8 @@ class LabeledMemoryLookup:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = MemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -160,6 +167,7 @@ class LabeledMemoryLookup:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py
@@ -40,6 +40,7 @@ class LabeledMemoryWithFeedbackMetrics:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (LabeledMemoryMetrics): Metrics computed for a labeled memory.
         label (int):
         label_name (Union[None, str]):
@@ -55,6 +56,7 @@ class LabeledMemoryWithFeedbackMetrics:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "LabeledMemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -81,6 +83,8 @@ class LabeledMemoryWithFeedbackMetrics:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -104,6 +108,7 @@ class LabeledMemoryWithFeedbackMetrics:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -148,6 +153,8 @@ class LabeledMemoryWithFeedbackMetrics:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = LabeledMemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -172,6 +179,7 @@ class LabeledMemoryWithFeedbackMetrics:
             memory_version=memory_version,
            created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
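
The same required edited_at field is added to all four memory models above (LabelPredictionMemoryLookup, LabeledMemory, LabeledMemoryLookup, LabeledMemoryWithFeedbackMetrics). Note that from_dict calls d.pop("edited_at") without a default, so a payload from an older server that lacks the key would raise KeyError. A minimal sketch of the isoformat/isoparse round trip these models rely on:

    import datetime

    from dateutil.parser import isoparse

    # to_dict() writes the new field as an ISO-8601 string ...
    edited_at = datetime.datetime(2024, 1, 15, 12, 30, tzinfo=datetime.timezone.utc)
    wire_value = edited_at.isoformat()  # '2024-01-15T12:30:00+00:00'

    # ... and from_dict() parses it back losslessly
    assert isoparse(wire_value) == edited_at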
orca_sdk/_generated_api_client/models/list_predictions_request.py
@@ -10,11 +10,13 @@ The main change is:
 
 # flake8: noqa: C901
 
+import datetime
 from enum import Enum
 from typing import Any, List, Type, TypeVar, Union, cast
 
 from attrs import define as _attrs_define
 from attrs import field as _attrs_field
+from dateutil.parser import isoparse
 
 from ..models.prediction_sort_item_item_type_0 import PredictionSortItemItemType0
 from ..models.prediction_sort_item_item_type_1 import PredictionSortItemItemType1
@@ -30,6 +32,8 @@ class ListPredictionsRequest:
         model_id (Union[None, Unset, str]):
         tag (Union[None, Unset, str]):
         prediction_ids (Union[List[str], None, Unset]):
+        start_timestamp (Union[None, Unset, datetime.datetime]):
+        end_timestamp (Union[None, Unset, datetime.datetime]):
         limit (Union[None, Unset, int]):
         offset (Union[None, Unset, int]): Default: 0.
         sort (Union[Unset, List[List[Union[PredictionSortItemItemType0, PredictionSortItemItemType1]]]]):
@@ -39,6 +43,8 @@
     model_id: Union[None, Unset, str] = UNSET
     tag: Union[None, Unset, str] = UNSET
     prediction_ids: Union[List[str], None, Unset] = UNSET
+    start_timestamp: Union[None, Unset, datetime.datetime] = UNSET
+    end_timestamp: Union[None, Unset, datetime.datetime] = UNSET
     limit: Union[None, Unset, int] = UNSET
     offset: Union[None, Unset, int] = 0
     sort: Union[Unset, List[List[Union[PredictionSortItemItemType0, PredictionSortItemItemType1]]]] = UNSET
@@ -67,6 +73,22 @@
         else:
             prediction_ids = self.prediction_ids
 
+        start_timestamp: Union[None, Unset, str]
+        if isinstance(self.start_timestamp, Unset):
+            start_timestamp = UNSET
+        elif isinstance(self.start_timestamp, datetime.datetime):
+            start_timestamp = self.start_timestamp.isoformat()
+        else:
+            start_timestamp = self.start_timestamp
+
+        end_timestamp: Union[None, Unset, str]
+        if isinstance(self.end_timestamp, Unset):
+            end_timestamp = UNSET
+        elif isinstance(self.end_timestamp, datetime.datetime):
+            end_timestamp = self.end_timestamp.isoformat()
+        else:
+            end_timestamp = self.end_timestamp
+
         limit: Union[None, Unset, int]
         if isinstance(self.limit, Unset):
             limit = UNSET
@@ -118,6 +140,10 @@
             field_dict["tag"] = tag
         if prediction_ids is not UNSET:
             field_dict["prediction_ids"] = prediction_ids
+        if start_timestamp is not UNSET:
+            field_dict["start_timestamp"] = start_timestamp
+        if end_timestamp is not UNSET:
+            field_dict["end_timestamp"] = end_timestamp
         if limit is not UNSET:
             field_dict["limit"] = limit
         if offset is not UNSET:
@@ -168,6 +194,40 @@
 
         prediction_ids = _parse_prediction_ids(d.pop("prediction_ids", UNSET))
 
+        def _parse_start_timestamp(data: object) -> Union[None, Unset, datetime.datetime]:
+            if data is None:
+                return data
+            if isinstance(data, Unset):
+                return data
+            try:
+                if not isinstance(data, str):
+                    raise TypeError()
+                start_timestamp_type_0 = isoparse(data)
+
+                return start_timestamp_type_0
+            except:  # noqa: E722
+                pass
+            return cast(Union[None, Unset, datetime.datetime], data)
+
+        start_timestamp = _parse_start_timestamp(d.pop("start_timestamp", UNSET))
+
+        def _parse_end_timestamp(data: object) -> Union[None, Unset, datetime.datetime]:
+            if data is None:
+                return data
+            if isinstance(data, Unset):
+                return data
+            try:
+                if not isinstance(data, str):
+                    raise TypeError()
+                end_timestamp_type_0 = isoparse(data)
+
+                return end_timestamp_type_0
+            except:  # noqa: E722
+                pass
+            return cast(Union[None, Unset, datetime.datetime], data)
+
+        end_timestamp = _parse_end_timestamp(d.pop("end_timestamp", UNSET))
+
         def _parse_limit(data: object) -> Union[None, Unset, int]:
             if data is None:
                 return data
@@ -231,6 +291,8 @@
             model_id=model_id,
             tag=tag,
             prediction_ids=prediction_ids,
+            start_timestamp=start_timestamp,
+            end_timestamp=end_timestamp,
             limit=limit,
             offset=offset,
             sort=sort,
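
ListPredictionsRequest thus gains optional start_timestamp and end_timestamp fields for filtering predictions to a time window. A minimal sketch of how they serialize (the model id shown is hypothetical):

    import datetime

    from orca_sdk._generated_api_client.models.list_predictions_request import ListPredictionsRequest

    req = ListPredictionsRequest(
        model_id="my-model-id",  # hypothetical id
        start_timestamp=datetime.datetime(2024, 1, 1),
        end_timestamp=datetime.datetime(2024, 2, 1),
    )

    body = req.to_dict()
    # Timestamps go over the wire as ISO-8601 strings; unset fields are omitted
    print(body["start_timestamp"])  # 2024-01-01T00:00:00
    print(body["end_timestamp"])    # 2024-02-01T00:00:00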
orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py
@@ -13,7 +13,6 @@ The main change is:
 from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast
 
 from attrs import define as _attrs_define
-from attrs import field as _attrs_field
 
 from ..types import UNSET, Unset
 
@@ -44,7 +43,6 @@ class MemorysetAnalysisConfigs:
     duplicate: Union["MemorysetDuplicateAnalysisConfig", None, Unset] = UNSET
     projection: Union["MemorysetProjectionAnalysisConfig", None, Unset] = UNSET
     cluster: Union["MemorysetClusterAnalysisConfig", None, Unset] = UNSET
-    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
 
     def to_dict(self) -> dict[str, Any]:
         from ..models.memoryset_cluster_analysis_config import MemorysetClusterAnalysisConfig
@@ -94,7 +92,6 @@ class MemorysetAnalysisConfigs:
         cluster = self.cluster
 
         field_dict: dict[str, Any] = {}
-        field_dict.update(self.additional_properties)
         field_dict.update({})
         if neighbor is not UNSET:
             field_dict["neighbor"] = neighbor
@@ -212,21 +209,4 @@ class MemorysetAnalysisConfigs:
             cluster=cluster,
         )
 
-        memoryset_analysis_configs.additional_properties = d
         return memoryset_analysis_configs
-
-    @property
-    def additional_keys(self) -> list[str]:
-        return list(self.additional_properties.keys())
-
-    def __getitem__(self, key: str) -> Any:
-        return self.additional_properties[key]
-
-    def __setitem__(self, key: str, value: Any) -> None:
-        self.additional_properties[key] = value
-
-    def __delitem__(self, key: str) -> None:
-        del self.additional_properties[key]
-
-    def __contains__(self, key: str) -> bool:
-        return key in self.additional_properties
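
MemorysetAnalysisConfigs loses its additional_properties passthrough, so unrecognized keys are now dropped rather than preserved and re-emitted. A sketch of the observable difference, assuming the generated from_dict/to_dict pattern shown above (the "extra" key is a hypothetical example):

    from orca_sdk._generated_api_client.models.memoryset_analysis_configs import MemorysetAnalysisConfigs

    # In 0.0.92 the unknown "extra" key would be kept in additional_properties
    # and echoed back by to_dict(); as of 0.0.94 it is silently dropped.
    configs = MemorysetAnalysisConfigs.from_dict({"extra": True})
    assert "extra" not in configs.to_dict()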
orca_sdk/_generated_api_client/models/prediction_request.py
@@ -28,14 +28,16 @@ class PredictionRequest:
         expected_labels (Union[List[int], None, Unset]):
         tags (Union[Unset, List[str]]):
         memoryset_override_id (Union[None, Unset, str]):
-        disable_telemetry (Union[Unset, bool]): Default: False.
+        save_telemetry (Union[Unset, bool]): Default: True.
+        save_telemetry_synchronously (Union[Unset, bool]): Default: False.
     """
 
     input_values: List[str]
     expected_labels: Union[List[int], None, Unset] = UNSET
     tags: Union[Unset, List[str]] = UNSET
     memoryset_override_id: Union[None, Unset, str] = UNSET
-    disable_telemetry: Union[Unset, bool] = False
+    save_telemetry: Union[Unset, bool] = True
+    save_telemetry_synchronously: Union[Unset, bool] = False
     additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
 
     def to_dict(self) -> dict[str, Any]:
@@ -62,7 +64,9 @@
         else:
             memoryset_override_id = self.memoryset_override_id
 
-        disable_telemetry = self.disable_telemetry
+        save_telemetry = self.save_telemetry
+
+        save_telemetry_synchronously = self.save_telemetry_synchronously
 
         field_dict: dict[str, Any] = {}
         field_dict.update(self.additional_properties)
@@ -77,8 +81,10 @@
             field_dict["tags"] = tags
         if memoryset_override_id is not UNSET:
             field_dict["memoryset_override_id"] = memoryset_override_id
-        if disable_telemetry is not UNSET:
-            field_dict["disable_telemetry"] = disable_telemetry
+        if save_telemetry is not UNSET:
+            field_dict["save_telemetry"] = save_telemetry
+        if save_telemetry_synchronously is not UNSET:
+            field_dict["save_telemetry_synchronously"] = save_telemetry_synchronously
 
         return field_dict
 
@@ -156,14 +162,17 @@
 
         memoryset_override_id = _parse_memoryset_override_id(d.pop("memoryset_override_id", UNSET))
 
-        disable_telemetry = d.pop("disable_telemetry", UNSET)
+        save_telemetry = d.pop("save_telemetry", UNSET)
+
+        save_telemetry_synchronously = d.pop("save_telemetry_synchronously", UNSET)
 
         prediction_request = cls(
             input_values=input_values,
             expected_labels=expected_labels,
             tags=tags,
            memoryset_override_id=memoryset_override_id,
-            disable_telemetry=disable_telemetry,
+            save_telemetry=save_telemetry,
+            save_telemetry_synchronously=save_telemetry_synchronously,
         )
 
         prediction_request.additional_properties = d
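
The boolean disable_telemetry flag is replaced by save_telemetry (default True) plus save_telemetry_synchronously (default False); callers that used disable_telemetry=True would now pass save_telemetry=False. A minimal sketch of the new wire format (the input text is hypothetical):

    from orca_sdk._generated_api_client.models.prediction_request import PredictionRequest

    # Old (0.0.92): PredictionRequest(input_values=[...], disable_telemetry=True)
    # New (0.0.94): opt out of telemetry with save_telemetry=False
    req = PredictionRequest(input_values=["some input text"], save_telemetry=False)

    body = req.to_dict()
    print(body["save_telemetry"])                # False
    print(body["save_telemetry_synchronously"])  # False (the default)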
orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py
@@ -2,11 +2,16 @@ from enum import Enum
 
 
 class PretrainedEmbeddingModelName(str, Enum):
+    BGE_BASE = "BGE_BASE"
     CDE_SMALL = "CDE_SMALL"
     CLIP_BASE = "CLIP_BASE"
     DISTILBERT = "DISTILBERT"
+    E5_LARGE = "E5_LARGE"
+    GIST_LARGE = "GIST_LARGE"
     GTE_BASE = "GTE_BASE"
     GTE_SMALL = "GTE_SMALL"
+    MXBAI_LARGE = "MXBAI_LARGE"
+    QWEN2_1_5B = "QWEN2_1_5B"
 
     def __str__(self) -> str:
         return str(self.value)
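
Five pretrained embedding model names are new in 0.0.94: BGE_BASE, E5_LARGE, GIST_LARGE, MXBAI_LARGE, and QWEN2_1_5B. Because the enum subclasses str, members compare and serialize as plain strings:

    from orca_sdk._generated_api_client.models.pretrained_embedding_model_name import (
        PretrainedEmbeddingModelName,
    )

    model = PretrainedEmbeddingModelName.BGE_BASE
    assert model == "BGE_BASE"       # str subclass: equal to its value
    assert str(model) == "BGE_BASE"  # __str__ returns the raw value

    # Enumerate everything available in 0.0.94
    print([m.value for m in PretrainedEmbeddingModelName])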
orca_sdk/_generated_api_client/models/validation_error.py
@@ -0,0 +1,99 @@
+"""
+This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+It is a customized template from the openapi-python-client tool's default template:
+https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+The main change is:
+- Fix typing issues
+"""
+
+# flake8: noqa: C901
+
+from typing import Any, List, Type, TypeVar, Union, cast
+
+from attrs import define as _attrs_define
+from attrs import field as _attrs_field
+
+T = TypeVar("T", bound="ValidationError")
+
+
+@_attrs_define
+class ValidationError:
+    """
+    Attributes:
+        loc (List[Union[int, str]]):
+        msg (str):
+        type (str):
+    """
+
+    loc: List[Union[int, str]]
+    msg: str
+    type: str
+    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        loc = []
+        for loc_item_data in self.loc:
+            loc_item: Union[int, str]
+            loc_item = loc_item_data
+            loc.append(loc_item)
+
+        msg = self.msg
+
+        type = self.type
+
+        field_dict: dict[str, Any] = {}
+        field_dict.update(self.additional_properties)
+        field_dict.update(
+            {
+                "loc": loc,
+                "msg": msg,
+                "type": type,
+            }
+        )
+
+        return field_dict
+
+    @classmethod
+    def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+        d = src_dict.copy()
+        loc = []
+        _loc = d.pop("loc")
+        for loc_item_data in _loc:
+
+            def _parse_loc_item(data: object) -> Union[int, str]:
+                return cast(Union[int, str], data)
+
+            loc_item = _parse_loc_item(loc_item_data)
+
+            loc.append(loc_item)
+
+        msg = d.pop("msg")
+
+        type = d.pop("type")
+
+        validation_error = cls(
+            loc=loc,
+            msg=msg,
+            type=type,
+        )
+
+        validation_error.additional_properties = d
+        return validation_error
+
+    @property
+    def additional_keys(self) -> list[str]:
+        return list(self.additional_properties.keys())
+
+    def __getitem__(self, key: str) -> Any:
+        return self.additional_properties[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self.additional_properties[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.additional_properties[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.additional_properties
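
ValidationError.loc mixes strings (field names) and ints (array indices), following the pydantic/FastAPI error-location convention. A small sketch with hypothetical values:

    from orca_sdk._generated_api_client.models.validation_error import ValidationError

    err = ValidationError.from_dict(
        {
            "loc": ["body", "expected_labels", 2],
            "msg": "value is not a valid integer",
            "type": "type_error.integer",
        }
    )
    # Join the mixed str/int path into a readable field reference
    field_path = ".".join(str(part) for part in err.loc)  # "body.expected_labels.2"
    print(f"{field_path}: {err.msg}")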
orca_sdk/_utils/data_parsing.py
@@ -1,12 +1,16 @@
+import logging
 import pickle
 from dataclasses import asdict, is_dataclass
 from os import PathLike
+from tempfile import TemporaryDirectory
 from typing import Any, cast
 
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
 
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
 
 def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
     if isinstance(item, dict):
@@ -40,7 +44,24 @@ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]
     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
 
 
-def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None) -> Dataset:
+def hf_dataset_from_torch(
+    torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None, ignore_cache=False
+) -> Dataset:
+    """
+    Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
+
+    NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
+    cached results can ignore changes you've made to tests. This can make a test appear to succeed
+    when it's actually broken or vice versa.
+
+    Params:
+        torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
+        column_names: Optional list of column names to use for the dataset. If not provided,
+            the column names will be inferred from the data.
+        ignore_cache: If True, the dataset will not be cached on disk.
+    Returns:
+        A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
+    """
     if isinstance(torch_data, TorchDataLoader):
         dataloader = torch_data
     else:
@@ -50,7 +71,15 @@ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_nam
         for batch in dataloader:
             yield from parse_batch(batch, column_names=column_names)
 
-    return cast(Dataset, Dataset.from_generator(generator))
+    if ignore_cache:
+        with TemporaryDirectory() as temp_dir:
+            ds = Dataset.from_generator(generator, cache_dir=temp_dir)
+    else:
+        ds = Dataset.from_generator(generator)
+
+    if not isinstance(ds, Dataset):
+        raise ValueError(f"Failed to create dataset from generator: {type(ds)}")
+    return ds
 
 
 def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
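
A sketch of the new ignore_cache flag in use with a toy in-memory dataset. The ToyDataset class is hypothetical, and this assumes hf_dataset_from_torch wraps a bare TorchDataset in a default DataLoader, which the hunk above does not show:

    from torch.utils.data import Dataset as TorchDataset

    from orca_sdk._utils.data_parsing import hf_dataset_from_torch


    class ToyDataset(TorchDataset):  # hypothetical example dataset
        def __len__(self) -> int:
            return 3

        def __getitem__(self, idx: int) -> dict:
            return {"text": f"example {idx}"}


    # ignore_cache=True materializes the dataset in a temporary cache dir, so a
    # stale on-disk cache cannot mask changes (the motivation given in the docstring)
    ds = hf_dataset_from_torch(ToyDataset(), ignore_cache=True)
    print(len(ds), ds.column_names)  # 3 ['text']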