superb-ai-onprem 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of superb-ai-onprem might be problematic. Click here for more details.
- spb_onprem/__init__.py +39 -4
- spb_onprem/_version.py +2 -2
- spb_onprem/data/entities/__init__.py +2 -0
- spb_onprem/data/entities/data.py +2 -0
- spb_onprem/data/entities/data_annotation_stats.py +8 -0
- spb_onprem/data/params/data_list.py +7 -2
- spb_onprem/data/params/update_data.py +11 -0
- spb_onprem/data/params/update_data_slice.py +14 -2
- spb_onprem/data/queries.py +12 -1
- spb_onprem/data/service.py +9 -0
- spb_onprem/entities.py +24 -2
- spb_onprem/models/__init__.py +8 -3
- spb_onprem/models/entities/__init__.py +9 -0
- spb_onprem/models/entities/model.py +32 -0
- spb_onprem/models/entities/model_page_info.py +14 -0
- spb_onprem/models/entities/model_train_class.py +15 -0
- spb_onprem/models/params/__init__.py +16 -4
- spb_onprem/models/params/create_model.py +70 -0
- spb_onprem/models/params/delete_model.py +11 -8
- spb_onprem/models/params/model.py +17 -0
- spb_onprem/models/params/models.py +60 -0
- spb_onprem/models/params/pin_model.py +17 -0
- spb_onprem/models/params/unpin_model.py +17 -0
- spb_onprem/models/params/update_model.py +61 -0
- spb_onprem/models/queries.py +224 -19
- spb_onprem/models/service.py +251 -30
- spb_onprem/reports/__init__.py +25 -0
- spb_onprem/reports/entities/__init__.py +10 -0
- spb_onprem/reports/entities/analytics_report.py +22 -0
- spb_onprem/reports/entities/analytics_report_item.py +30 -0
- spb_onprem/reports/entities/analytics_report_page_info.py +14 -0
- spb_onprem/reports/params/__init__.py +29 -0
- spb_onprem/reports/params/analytics_report.py +17 -0
- spb_onprem/reports/params/analytics_reports.py +87 -0
- spb_onprem/reports/params/create_analytics_report.py +35 -0
- spb_onprem/reports/params/create_analytics_report_item.py +47 -0
- spb_onprem/reports/params/delete_analytics_report.py +17 -0
- spb_onprem/reports/params/delete_analytics_report_item.py +20 -0
- spb_onprem/reports/params/update_analytics_report.py +38 -0
- spb_onprem/reports/params/update_analytics_report_item.py +46 -0
- spb_onprem/reports/queries.py +239 -0
- spb_onprem/reports/service.py +328 -0
- spb_onprem/searches.py +18 -0
- {superb_ai_onprem-0.9.1.dist-info → superb_ai_onprem-0.10.0.dist-info}/METADATA +53 -9
- {superb_ai_onprem-0.9.1.dist-info → superb_ai_onprem-0.10.0.dist-info}/RECORD +48 -38
- spb_onprem/models/entities.py +0 -9
- spb_onprem/models/params/get_models.py +0 -29
- spb_onprem/predictions/__init__.py +0 -7
- spb_onprem/predictions/entities.py +0 -11
- spb_onprem/predictions/params/__init__.py +0 -15
- spb_onprem/predictions/params/create_prediction_set.py +0 -44
- spb_onprem/predictions/params/delete_prediction_from_data.py +0 -20
- spb_onprem/predictions/params/delete_prediction_set.py +0 -14
- spb_onprem/predictions/params/get_prediction_set.py +0 -14
- spb_onprem/predictions/params/get_prediction_sets.py +0 -29
- spb_onprem/predictions/params/update_prediction_set_data_info.py +0 -28
- spb_onprem/predictions/queries.py +0 -110
- spb_onprem/predictions/service.py +0 -225
- tests/models/__init__.py +0 -1
- tests/models/test_model_service.py +0 -249
- tests/predictions/__init__.py +0 -1
- tests/predictions/test_prediction_service.py +0 -359
- {superb_ai_onprem-0.9.1.dist-info → superb_ai_onprem-0.10.0.dist-info}/WHEEL +0 -0
- {superb_ai_onprem-0.9.1.dist-info → superb_ai_onprem-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {superb_ai_onprem-0.9.1.dist-info → superb_ai_onprem-0.10.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
def pin_model_params(
    dataset_id: str,
    model_id: str,
):
    """Build the GraphQL variables for the ``pinModel`` mutation.

    Args:
        dataset_id: ID of the dataset that owns the model.
        model_id: ID of the model to pin; sent under the ``id`` key.

    Returns:
        dict: ``{"datasetId": dataset_id, "id": model_id}``.
    """
    return dict(datasetId=dataset_id, id=model_id)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
def unpin_model_params(
    dataset_id: str,
    model_id: str,
):
    """Build the GraphQL variables for the ``unpinModel`` mutation.

    Args:
        dataset_id: ID of the dataset that owns the model.
        model_id: ID of the model to unpin; sent under the ``id`` key.

    Returns:
        dict: ``{"datasetId": dataset_id, "id": model_id}``.
    """
    variable_names = ("datasetId", "id")
    return dict(zip(variable_names, (dataset_id, model_id)))
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import List, Union, Any
|
|
2
|
+
from spb_onprem.base_types import Undefined, UndefinedType
|
|
3
|
+
from spb_onprem.models.entities.model_train_class import ModelTrainClass
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def update_model_params(
    dataset_id: str,
    model_id: str,
    name: Union[str, UndefinedType] = Undefined,
    description: Union[str, UndefinedType] = Undefined,
    training_classes: Union[List[ModelTrainClass], UndefinedType] = Undefined,
    model_content_id: Union[str, UndefinedType] = Undefined,
    is_trained: Union[bool, UndefinedType] = Undefined,
    trained_at: Union[str, UndefinedType] = Undefined,
    meta: Union[Any, UndefinedType] = Undefined,
):
    """Build the GraphQL variables for the ``updateModel`` mutation.

    ``datasetId`` and ``id`` are always present. Every other field is added
    only when its argument was supplied, i.e. is not an ``UndefinedType``
    instance. An explicit ``None`` is forwarded as ``None``.

    Args:
        dataset_id: ID of the dataset that owns the model.
        model_id: ID of the model to update; sent under the ``id`` key.
        name: New model name.
        description: New description.
        training_classes: New training classes. Each item is serialized with
            ``model_dump(by_alias=True, exclude_unset=True)``.
        model_content_id: New model content ID.
        is_trained: New trained flag.
        trained_at: New trained timestamp.
        meta: New metadata.

    Returns:
        dict: Variables keyed by their GraphQL (camelCase) names.
    """
    # Serialize the pydantic items up front. Undefined and None pass through
    # untouched so the filter below can treat them like any other field.
    if isinstance(training_classes, UndefinedType) or training_classes is None:
        serialized_classes = training_classes
    else:
        serialized_classes = [
            train_class.model_dump(by_alias=True, exclude_unset=True)
            for train_class in training_classes
        ]

    optional_fields = (
        ("name", name),
        ("description", description),
        ("trainingClasses", serialized_classes),
        ("modelContentId", model_content_id),
        ("isTrained", is_trained),
        ("trainedAt", trained_at),
        ("meta", meta),
    )

    variables = {"datasetId": dataset_id, "id": model_id}
    variables.update(
        (key, value)
        for key, value in optional_fields
        if not isinstance(value, UndefinedType)
    )
    return variables
|
spb_onprem/models/queries.py
CHANGED
|
@@ -1,33 +1,238 @@
|
|
|
1
|
-
from .params import (
|
|
2
|
-
|
|
1
|
+
from spb_onprem.models.params import (
|
|
2
|
+
models_params,
|
|
3
|
+
model_params,
|
|
4
|
+
create_model_params,
|
|
5
|
+
update_model_params,
|
|
6
|
+
pin_model_params,
|
|
7
|
+
unpin_model_params,
|
|
3
8
|
delete_model_params,
|
|
4
9
|
)
|
|
5
10
|
|
|
6
11
|
|
|
7
|
-
class
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
12
|
+
class Schemas:
    """GraphQL selection-set fragments shared by the model queries and mutations.

    The fragments are plain field lists (no surrounding braces) so they can be
    interpolated inside ``{ ... }`` of a larger selection.
    """

    # Fields selected for each entry of a model's ``trainingClasses`` list.
    MODEL_TRAIN_CLASS = '''
        class
        annotationType
        ap
        trainingAnnotationsCount
        validationAnnotationsCount
    '''

    # Fields selected for each slice a model references. Shared by
    # ``trainingSlices`` and ``validationSlices`` so the two cannot drift apart.
    MODEL_SLICE = '''
        id
        datasetId
        name
        description
        isPinned
        createdAt
        createdBy
        updatedAt
        updatedBy
    '''

    # Full model selection. MODEL_TRAIN_CLASS and MODEL_SLICE are interpolated
    # (class-body names are visible to this f-string) instead of being copied
    # inline. Literal GraphQL braces are therefore doubled.
    MODEL = f'''
        id
        datasetId
        baselineModel
        name
        description
        trainingClasses {{
            {MODEL_TRAIN_CLASS}
        }}
        trainingDataCount
        validationDataCount
        isPinned
        isTrained
        trainedAt
        modelContent {{
            id
            downloadURL
        }}
        createdBy
        createdAt
        updatedBy
        updatedAt
        meta
        trainingSlices {{
            {MODEL_SLICE}
        }}
        validationSlices {{
            {MODEL_SLICE}
        }}
    '''
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Queries():
    """GraphQL operations used by the model service.

    Each attribute is a dict with three entries:

    * ``"name"``: the root field the operation selects (e.g. ``models``).
    * ``"query"``: the GraphQL document text.
    * ``"variables"``: a params builder from ``spb_onprem.models.params``.
      ``ModelService`` calls it to build the operation's variables.

    Most documents are f-strings so they can embed ``Schemas.MODEL``. Literal
    GraphQL braces inside them are therefore doubled (``{{`` / ``}}``).
    """

    # Paginated, optionally filtered list of models in a dataset. The
    # response carries the page (`models`), the `next` cursor and `totalCount`.
    MODELS = {
        "name": "models",
        "query": f'''
            query Models(
                $datasetId: ID!,
                $filter: ModelFilter,
                $cursor: String,
                $length: Int
            ) {{
                models(
                    datasetId: $datasetId,
                    filter: $filter,
                    cursor: $cursor,
                    length: $length
                ) {{
                    models {{
                        {Schemas.MODEL}
                    }}
                    next
                    totalCount
                }}
            }}
        ''',
        "variables": models_params,
    }

    # Single model lookup by ID.
    # NOTE(review): this query declares `$modelId`, while the mutations below
    # declare `$id`. Confirm that `model_params` emits a "modelId" key.
    MODEL = {
        "name": "model",
        "query": f'''
            query Model(
                $datasetId: ID!,
                $modelId: ID!
            ) {{
                model(
                    datasetId: $datasetId,
                    id: $modelId
                ) {{
                    {Schemas.MODEL}
                }}
            }}
        ''',
        "variables": model_params,
    }

    # Create a model. name, baselineModel and both slice-ID lists are
    # non-null (`!`) in the declared variable types; everything else is optional.
    CREATE_MODEL = {
        "name": "createModel",
        "query": f'''
            mutation CreateModel(
                $datasetId: ID!,
                $name: String!,
                $description: String,
                $baselineModel: String!,
                $trainingClasses: [ModelTrainClassInput!],
                $trainingSliceIds: [ID!]!,
                $validationSliceIds: [ID!]!,
                $modelContentId: String,
                $isTrained: Boolean,
                $trainedAt: DateTime,
                $isPinned: Boolean,
                $meta: JSONObject
            ) {{
                createModel(
                    datasetId: $datasetId,
                    name: $name,
                    description: $description,
                    baselineModel: $baselineModel,
                    trainingClasses: $trainingClasses,
                    trainingSliceIds: $trainingSliceIds,
                    validationSliceIds: $validationSliceIds,
                    modelContentId: $modelContentId,
                    isTrained: $isTrained,
                    trainedAt: $trainedAt,
                    isPinned: $isPinned,
                    meta: $meta
                ) {{
                    {Schemas.MODEL}
                }}
            }}
        ''',
        "variables": create_model_params,
    }

    # Partial update of a model. Every field except datasetId/id is nullable,
    # and update_model_params only includes the fields the caller supplied.
    UPDATE_MODEL = {
        "name": "updateModel",
        "query": f'''
            mutation UpdateModel(
                $datasetId: ID!,
                $id: ID!,
                $name: String,
                $description: String,
                $trainingClasses: [ModelTrainClassInput!],
                $modelContentId: String,
                $isTrained: Boolean,
                $trainedAt: DateTime,
                $meta: JSONObject
            ) {{
                updateModel(
                    datasetId: $datasetId,
                    id: $id,
                    name: $name,
                    description: $description,
                    trainingClasses: $trainingClasses,
                    modelContentId: $modelContentId,
                    isTrained: $isTrained,
                    trainedAt: $trainedAt,
                    meta: $meta
                ) {{
                    {Schemas.MODEL}
                }}
            }}
        ''',
        "variables": update_model_params,
    }

    # Pin a model; returns the full model selection.
    PIN_MODEL = {
        "name": "pinModel",
        "query": f'''
            mutation PinModel(
                $datasetId: ID!,
                $id: ID!
            ) {{
                pinModel(
                    datasetId: $datasetId,
                    id: $id
                ) {{
                    {Schemas.MODEL}
                }}
            }}
        ''',
        "variables": pin_model_params,
    }

    # Unpin a model; returns the full model selection.
    UNPIN_MODEL = {
        "name": "unpinModel",
        "query": f'''
            mutation UnpinModel(
                $datasetId: ID!,
                $id: ID!
            ) {{
                unpinModel(
                    datasetId: $datasetId,
                    id: $id
                ) {{
                    {Schemas.MODEL}
                }}
            }}
        ''',
        "variables": unpin_model_params,
    }

    # Delete a model. This is a plain (non-f) string, so its braces are single.
    # It selects no subfields, so the result is presumably a scalar:
    # ModelService.delete_model annotates its return as bool.
    DELETE_MODEL = {
        "name": "deleteModel",
        "query": '''
            mutation DeleteModel(
                $datasetId: ID!,
                $id: ID!
            ) {
                deleteModel(
                    datasetId: $datasetId,
                    id: $id
                )
            }
        ''',
        "variables": delete_model_params,
    }
|
spb_onprem/models/service.py
CHANGED
|
@@ -1,68 +1,289 @@
|
|
|
1
|
-
from typing import Optional,
|
|
2
|
-
|
|
1
|
+
from typing import Optional, Union, List, Any
|
|
3
2
|
from spb_onprem.base_service import BaseService
|
|
4
|
-
from spb_onprem.base_types import Undefined, UndefinedType
|
|
5
3
|
from spb_onprem.exceptions import BadParameterError
|
|
4
|
+
from spb_onprem.base_types import Undefined, UndefinedType
|
|
6
5
|
from .queries import Queries
|
|
7
|
-
from .entities import Model
|
|
6
|
+
from .entities import Model, ModelPageInfo, ModelTrainClass
|
|
7
|
+
from .params.models import ModelsFilter
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class ModelService(BaseService):
    """
    Service class for handling model-related operations.

    Each method validates its required arguments, builds variables with the
    matching ``Queries`` entry and sends the request through
    ``BaseService.request_gql``.
    """

    def get_models(
        self,
        dataset_id: str,
        models_filter: Optional[ModelsFilter] = None,
        cursor: Optional[str] = None,
        length: Optional[int] = 10
    ):
        """
        Get a list of models based on the provided filter and pagination parameters.

        Args:
            dataset_id (str): The dataset ID
            models_filter (Optional[ModelsFilter]): Filter criteria for models
            cursor (Optional[str]): Cursor for pagination
            length (Optional[int]): Number of items per page (default: 10, max: 50).
                ``None`` skips the client-side cap and is forwarded unchanged.

        Returns:
            tuple: A tuple containing:
                - List[Model]: A list of Model objects (empty list if none)
                - str: Next cursor for pagination
                - int: Total count of models (0 if the response omits it)

        Raises:
            BadParameterError: If ``dataset_id`` is None or ``length`` exceeds 50.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        # `length` is Optional, so only enforce the cap when a value is given.
        # Previously `None > 50` raised TypeError instead of sending the request.
        # NOTE(review): confirm models_params accepts length=None.
        if length is not None and length > 50:
            raise BadParameterError("The maximum length is 50.")

        response = self.request_gql(
            Queries.MODELS,
            Queries.MODELS["variables"](
                dataset_id=dataset_id,
                models_filter=models_filter,
                cursor=cursor,
                length=length
            )
        )

        page_info = ModelPageInfo.model_validate(response)
        return (
            page_info.models or [],
            page_info.next,
            page_info.total_count or 0
        )

    def get_model(
        self,
        dataset_id: str,
        model_id: str,
    ) -> Model:
        """
        Retrieve a model by its ID.

        Args:
            dataset_id (str): The dataset ID
            model_id (str): The ID of the model to retrieve

        Returns:
            Model: The retrieved model object

        Raises:
            BadParameterError: If ``dataset_id`` or ``model_id`` is None.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        if model_id is None:
            raise BadParameterError("model_id is required.")

        response = self.request_gql(
            Queries.MODEL,
            Queries.MODEL["variables"](
                dataset_id=dataset_id,
                model_id=model_id
            ),
        )
        return Model.model_validate(response)

    def create_model(
        self,
        dataset_id: str,
        name: str,
        baseline_model: str,
        training_slice_ids: List[str],
        validation_slice_ids: List[str],
        description: Union[str, UndefinedType] = Undefined,
        training_classes: Union[List[ModelTrainClass], UndefinedType] = Undefined,
        model_content_id: Union[str, UndefinedType] = Undefined,
        is_trained: Union[bool, UndefinedType] = Undefined,
        trained_at: Union[str, UndefinedType] = Undefined,
        is_pinned: Union[bool, UndefinedType] = Undefined,
        meta: Union[Any, UndefinedType] = Undefined,
    ) -> Model:
        """
        Create a new model.

        Args:
            dataset_id (str): The dataset ID
            name (str): The model name
            baseline_model (str): The baseline model used
            training_slice_ids (List[str]): The IDs of the training slices
            validation_slice_ids (List[str]): The IDs of the validation slices
            description (Optional[str]): The description of the model
            training_classes (Optional[List[ModelTrainClass]]): The training classes
            model_content_id (Optional[str]): The model content ID
            is_trained (Optional[bool]): Whether the model is trained
            trained_at (Optional[str]): When the model was trained
            is_pinned (Optional[bool]): Whether the model is pinned
            meta (Optional[Any]): The metadata of the model

        Returns:
            Model: The created model object

        Raises:
            BadParameterError: If any of ``dataset_id``, ``name``,
                ``baseline_model``, ``training_slice_ids`` or
                ``validation_slice_ids`` is None.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        if name is None:
            raise BadParameterError("name is required.")

        if baseline_model is None:
            raise BadParameterError("baseline_model is required.")

        if training_slice_ids is None:
            raise BadParameterError("training_slice_ids is required.")

        if validation_slice_ids is None:
            raise BadParameterError("validation_slice_ids is required.")

        response = self.request_gql(
            Queries.CREATE_MODEL,
            Queries.CREATE_MODEL["variables"](
                dataset_id=dataset_id,
                name=name,
                baseline_model=baseline_model,
                training_slice_ids=training_slice_ids,
                validation_slice_ids=validation_slice_ids,
                description=description,
                training_classes=training_classes,
                model_content_id=model_content_id,
                is_trained=is_trained,
                trained_at=trained_at,
                is_pinned=is_pinned,
                meta=meta,
            ),
        )
        return Model.model_validate(response)

    def update_model(
        self,
        dataset_id: str,
        model_id: str,
        name: Union[str, UndefinedType] = Undefined,
        description: Union[str, UndefinedType] = Undefined,
        training_classes: Union[List[ModelTrainClass], UndefinedType] = Undefined,
        model_content_id: Union[str, UndefinedType] = Undefined,
        is_trained: Union[bool, UndefinedType] = Undefined,
        trained_at: Union[str, UndefinedType] = Undefined,
        meta: Union[Any, UndefinedType] = Undefined,
    ) -> Model:
        """
        Update a model.

        Arguments left as ``Undefined`` are not forwarded as fields to
        update_model_params.

        Args:
            dataset_id (str): The dataset ID
            model_id (str): The ID of the model to update
            name (Optional[str]): The new name
            description (Optional[str]): The new description
            training_classes (Optional[List[ModelTrainClass]]): The new training classes
            model_content_id (Optional[str]): The new model content ID
            is_trained (Optional[bool]): The new trained status
            trained_at (Optional[str]): The new trained timestamp
            meta (Optional[Any]): The new metadata

        Returns:
            Model: The updated model object

        Raises:
            BadParameterError: If ``dataset_id`` or ``model_id`` is None.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        if model_id is None:
            raise BadParameterError("model_id is required.")

        response = self.request_gql(
            Queries.UPDATE_MODEL,
            Queries.UPDATE_MODEL["variables"](
                dataset_id=dataset_id,
                model_id=model_id,
                name=name,
                description=description,
                training_classes=training_classes,
                model_content_id=model_content_id,
                is_trained=is_trained,
                trained_at=trained_at,
                meta=meta,
            ),
        )
        return Model.model_validate(response)

    def pin_model(
        self,
        dataset_id: str,
        model_id: str,
    ) -> Model:
        """
        Pin a model.

        Args:
            dataset_id (str): The dataset ID
            model_id (str): The ID of the model to pin

        Returns:
            Model: The pinned model object

        Raises:
            BadParameterError: If ``dataset_id`` or ``model_id`` is None.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        if model_id is None:
            raise BadParameterError("model_id is required.")

        response = self.request_gql(
            Queries.PIN_MODEL,
            Queries.PIN_MODEL["variables"](
                dataset_id=dataset_id,
                model_id=model_id,
            ),
        )
        return Model.model_validate(response)

    def unpin_model(
        self,
        dataset_id: str,
        model_id: str,
    ) -> Model:
        """
        Unpin a model.

        Args:
            dataset_id (str): The dataset ID
            model_id (str): The ID of the model to unpin

        Returns:
            Model: The unpinned model object

        Raises:
            BadParameterError: If ``dataset_id`` or ``model_id`` is None.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        if model_id is None:
            raise BadParameterError("model_id is required.")

        response = self.request_gql(
            Queries.UNPIN_MODEL,
            Queries.UNPIN_MODEL["variables"](
                dataset_id=dataset_id,
                model_id=model_id,
            ),
        )
        return Model.model_validate(response)

    def delete_model(
        self,
        dataset_id: str,
        model_id: str,
    ) -> bool:
        """Delete a model.

        Args:
            dataset_id (str): The dataset ID
            model_id (str): The ID of the model to delete

        Returns:
            bool: True if deletion was successful. The raw ``request_gql``
                result is returned unchanged.

        Raises:
            BadParameterError: If ``dataset_id`` or ``model_id`` is None.
        """
        if dataset_id is None:
            raise BadParameterError("dataset_id is required.")

        if model_id is None:
            raise BadParameterError("model_id is required.")

        response = self.request_gql(
            Queries.DELETE_MODEL,
            Queries.DELETE_MODEL["variables"](
                dataset_id=dataset_id,
                model_id=model_id,
            )
        )
        return response
|