adaptive-sdk 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_sdk/client.py +2 -0
- adaptive_sdk/graphql_client/__init__.py +6 -4
- adaptive_sdk/graphql_client/add_model_to_use_case.py +6 -0
- adaptive_sdk/graphql_client/async_client.py +35 -19
- adaptive_sdk/graphql_client/client.py +35 -19
- adaptive_sdk/graphql_client/custom_fields.py +20 -4
- adaptive_sdk/graphql_client/custom_mutations.py +29 -8
- adaptive_sdk/graphql_client/deploy_model.py +12 -0
- adaptive_sdk/graphql_client/enums.py +3 -20
- adaptive_sdk/graphql_client/fragments.py +4 -4
- adaptive_sdk/graphql_client/input_types.py +157 -18
- adaptive_sdk/graphql_client/remove_model_from_use_case.py +6 -0
- adaptive_sdk/input_types/typed_dicts.py +14 -15
- adaptive_sdk/resources/__init__.py +3 -0
- adaptive_sdk/resources/artifacts.py +61 -0
- adaptive_sdk/resources/chat.py +11 -9
- adaptive_sdk/resources/interactions.py +57 -25
- adaptive_sdk/resources/models.py +86 -95
- adaptive_sdk/rest/rest_types.py +2 -1
- {adaptive_sdk-0.1.13.dist-info → adaptive_sdk-0.1.14.dist-info}/METADATA +4 -1
- {adaptive_sdk-0.1.13.dist-info → adaptive_sdk-0.1.14.dist-info}/RECORD +22 -19
- adaptive_sdk/graphql_client/attach_model_to_use_case.py +0 -12
- {adaptive_sdk-0.1.13.dist-info → adaptive_sdk-0.1.14.dist-info}/WHEEL +0 -0
|
@@ -3,8 +3,8 @@ from graphql import DocumentNode, NamedTypeNode, NameNode, OperationDefinitionNo
|
|
|
3
3
|
from .add_external_model import AddExternalModel
|
|
4
4
|
from .add_hf_model import AddHFModel
|
|
5
5
|
from .add_model import AddModel
|
|
6
|
+
from .add_model_to_use_case import AddModelToUseCase
|
|
6
7
|
from .add_remote_env import AddRemoteEnv
|
|
7
|
-
from .attach_model_to_use_case import AttachModelToUseCase
|
|
8
8
|
from .base_client_open_telemetry import BaseClientOpenTelemetry
|
|
9
9
|
from .base_model import UNSET, UnsetType, Upload
|
|
10
10
|
from .base_operation import GraphQLField
|
|
@@ -28,6 +28,7 @@ from .delete_dataset import DeleteDataset
|
|
|
28
28
|
from .delete_grader import DeleteGrader
|
|
29
29
|
from .delete_judge import DeleteJudge
|
|
30
30
|
from .delete_user import DeleteUser
|
|
31
|
+
from .deploy_model import DeployModel
|
|
31
32
|
from .describe_ab_campaign import DescribeAbCampaign
|
|
32
33
|
from .describe_dataset import DescribeDataset
|
|
33
34
|
from .describe_interaction import DescribeInteraction
|
|
@@ -41,7 +42,7 @@ from .enums import CompletionGroupBy
|
|
|
41
42
|
from .get_custom_recipe import GetCustomRecipe
|
|
42
43
|
from .get_grader import GetGrader
|
|
43
44
|
from .get_judge import GetJudge
|
|
44
|
-
from .input_types import AbcampaignCreate, AbCampaignFilter, AddExternalModelInput, AddHFModelInput, AddModelInput,
|
|
45
|
+
from .input_types import AbcampaignCreate, AbCampaignFilter, AddExternalModelInput, AddHFModelInput, AddModelInput, AddModelToUseCaseInput, CreateRecipeInput, CursorPageInput, CustomRecipeFilterInput, DatasetCreate, DatasetCreateFromMultipartUpload, DatasetUploadProcessingStatusInput, DeployModelInput, GraderCreateInput, GraderUpdateInput, JobInput, JudgeCreate, JudgeUpdate, ListCompletionsFilterInput, ListJobsFilterInput, MetricCreate, MetricLink, MetricUnlink, ModelComputeConfigInput, ModelFilter, OrderPair, PrebuiltJudgeCreate, RemoteEnvCreate, RemoveModelFromUseCaseInput, ResizePartitionInput, RoleCreate, TeamCreate, TeamMemberRemove, TeamMemberSet, UpdateModelService, UpdateRecipeInput, UseCaseCreate, UseCaseShares, UserCreate
|
|
45
46
|
from .link_metric import LinkMetric
|
|
46
47
|
from .list_ab_campaigns import ListAbCampaigns
|
|
47
48
|
from .list_compute_pools import ListComputePools
|
|
@@ -65,6 +66,7 @@ from .list_users import ListUsers
|
|
|
65
66
|
from .load_dataset import LoadDataset
|
|
66
67
|
from .lock_grader import LockGrader
|
|
67
68
|
from .me import Me
|
|
69
|
+
from .remove_model_from_use_case import RemoveModelFromUseCase
|
|
68
70
|
from .remove_remote_env import RemoveRemoteEnv
|
|
69
71
|
from .remove_team_member import RemoveTeamMember
|
|
70
72
|
from .resize_inference_partition import ResizeInferencePartition
|
|
@@ -106,13 +108,6 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
106
108
|
data = self.get_data(response)
|
|
107
109
|
return UnlinkMetric.model_validate(data)
|
|
108
110
|
|
|
109
|
-
def attach_model_to_use_case(self, input: AttachModel, **kwargs: Any) -> AttachModelToUseCase:
|
|
110
|
-
query = gql('\n mutation AttachModelToUseCase($input: AttachModel!) {\n attachModel(input: $input) {\n ...ModelServiceData\n }\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n attached\n isDefault\n desiredOnline\n createdAt\n }\n ')
|
|
111
|
-
variables: Dict[str, object] = {'input': input}
|
|
112
|
-
response = self.execute(query=query, operation_name='AttachModelToUseCase', variables=variables, **kwargs)
|
|
113
|
-
data = self.get_data(response)
|
|
114
|
-
return AttachModelToUseCase.model_validate(data)
|
|
115
|
-
|
|
116
111
|
def add_external_model(self, input: AddExternalModelInput, **kwargs: Any) -> AddExternalModel:
|
|
117
112
|
query = gql('\n mutation AddExternalModel($input: AddExternalModelInput!) {\n addExternalModel(input: $input) {\n ...ModelData\n }\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n ')
|
|
118
113
|
variables: Dict[str, object] = {'input': input}
|
|
@@ -128,7 +123,7 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
128
123
|
return AddModel.model_validate(data)
|
|
129
124
|
|
|
130
125
|
def update_model(self, input: UpdateModelService, **kwargs: Any) -> UpdateModel:
|
|
131
|
-
query = gql('\n mutation UpdateModel($input: UpdateModelService!) {\n updateModelService(input: $input) {\n ...ModelServiceData\n }\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n
|
|
126
|
+
query = gql('\n mutation UpdateModel($input: UpdateModelService!) {\n updateModelService(input: $input) {\n ...ModelServiceData\n }\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n isDefault\n desiredOnline\n createdAt\n }\n ')
|
|
132
127
|
variables: Dict[str, object] = {'input': input}
|
|
133
128
|
response = self.execute(query=query, operation_name='UpdateModel', variables=variables, **kwargs)
|
|
134
129
|
data = self.get_data(response)
|
|
@@ -142,14 +137,14 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
142
137
|
return TerminateModel.model_validate(data)
|
|
143
138
|
|
|
144
139
|
def create_use_case(self, input: UseCaseCreate, **kwargs: Any) -> CreateUseCase:
|
|
145
|
-
query = gql('\n mutation CreateUseCase($input: UseCaseCreate!) {\n createUseCase(input: $input) {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n
|
|
140
|
+
query = gql('\n mutation CreateUseCase($input: UseCaseCreate!) {\n createUseCase(input: $input) {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n isDefault\n desiredOnline\n createdAt\n }\n\n fragment UseCaseData on UseCase {\n id\n key\n name\n description\n createdAt\n metrics {\n ...MetricWithContextData\n }\n modelServices {\n ...ModelServiceData\n }\n permissions\n shares {\n team {\n id\n key\n name\n createdAt\n }\n role {\n id\n key\n name\n createdAt\n permissions\n }\n isOwner\n }\n }\n ')
|
|
146
141
|
variables: Dict[str, object] = {'input': input}
|
|
147
142
|
response = self.execute(query=query, operation_name='CreateUseCase', variables=variables, **kwargs)
|
|
148
143
|
data = self.get_data(response)
|
|
149
144
|
return CreateUseCase.model_validate(data)
|
|
150
145
|
|
|
151
146
|
def share_use_case(self, id_or_key: str, input: UseCaseShares, **kwargs: Any) -> ShareUseCase:
|
|
152
|
-
query = gql('\n mutation ShareUseCase($idOrKey: IdOrKey!, $input: UseCaseShares!) {\n shareUseCase(idOrKey: $idOrKey, input: $input) {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n
|
|
147
|
+
query = gql('\n mutation ShareUseCase($idOrKey: IdOrKey!, $input: UseCaseShares!) {\n shareUseCase(idOrKey: $idOrKey, input: $input) {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n isDefault\n desiredOnline\n createdAt\n }\n\n fragment UseCaseData on UseCase {\n id\n key\n name\n description\n createdAt\n metrics {\n ...MetricWithContextData\n }\n modelServices {\n ...ModelServiceData\n }\n permissions\n shares {\n team {\n id\n key\n name\n createdAt\n }\n role {\n id\n key\n name\n createdAt\n permissions\n }\n isOwner\n }\n }\n ')
|
|
153
148
|
variables: Dict[str, object] = {'idOrKey': id_or_key, 'input': input}
|
|
154
149
|
response = self.execute(query=query, operation_name='ShareUseCase', variables=variables, **kwargs)
|
|
155
150
|
data = self.get_data(response)
|
|
@@ -184,7 +179,7 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
184
179
|
return DeleteDataset.model_validate(data)
|
|
185
180
|
|
|
186
181
|
def add_hf_model(self, input: AddHFModelInput, **kwargs: Any) -> AddHFModel:
|
|
187
|
-
query = gql('\n mutation AddHFModel($input: AddHFModelInput!) {\n importHfModel(input: $input) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
182
|
+
query = gql('\n mutation AddHFModel($input: AddHFModelInput!) {\n importHfModel(input: $input) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n status\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
188
183
|
variables: Dict[str, object] = {'input': input}
|
|
189
184
|
response = self.execute(query=query, operation_name='AddHFModel', variables=variables, **kwargs)
|
|
190
185
|
data = self.get_data(response)
|
|
@@ -338,14 +333,14 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
338
333
|
return RemoveTeamMember.model_validate(data)
|
|
339
334
|
|
|
340
335
|
def create_job(self, input: JobInput, **kwargs: Any) -> CreateJob:
|
|
341
|
-
query = gql('\n mutation CreateJob($input: JobInput!) {\n createJob(input: $input) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
336
|
+
query = gql('\n mutation CreateJob($input: JobInput!) {\n createJob(input: $input) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n status\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
342
337
|
variables: Dict[str, object] = {'input': input}
|
|
343
338
|
response = self.execute(query=query, operation_name='CreateJob', variables=variables, **kwargs)
|
|
344
339
|
data = self.get_data(response)
|
|
345
340
|
return CreateJob.model_validate(data)
|
|
346
341
|
|
|
347
342
|
def cancel_job(self, job_id: Any, **kwargs: Any) -> CancelJob:
|
|
348
|
-
query = gql('\n mutation CancelJob($jobId: UUID!) {\n cancelJob(id: $jobId) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
343
|
+
query = gql('\n mutation CancelJob($jobId: UUID!) {\n cancelJob(id: $jobId) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n status\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
349
344
|
variables: Dict[str, object] = {'jobId': job_id}
|
|
350
345
|
response = self.execute(query=query, operation_name='CancelJob', variables=variables, **kwargs)
|
|
351
346
|
data = self.get_data(response)
|
|
@@ -365,6 +360,27 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
365
360
|
data = self.get_data(response)
|
|
366
361
|
return ResizeInferencePartition.model_validate(data)
|
|
367
362
|
|
|
363
|
+
def add_model_to_use_case(self, input: AddModelToUseCaseInput, **kwargs: Any) -> AddModelToUseCase:
|
|
364
|
+
query = gql('\n mutation AddModelToUseCase($input: AddModelToUseCaseInput!) {\n addModelToUseCase(input: $input)\n }\n ')
|
|
365
|
+
variables: Dict[str, object] = {'input': input}
|
|
366
|
+
response = self.execute(query=query, operation_name='AddModelToUseCase', variables=variables, **kwargs)
|
|
367
|
+
data = self.get_data(response)
|
|
368
|
+
return AddModelToUseCase.model_validate(data)
|
|
369
|
+
|
|
370
|
+
def remove_model_from_use_case(self, input: RemoveModelFromUseCaseInput, **kwargs: Any) -> RemoveModelFromUseCase:
|
|
371
|
+
query = gql('\n mutation RemoveModelFromUseCase($input: RemoveModelFromUseCaseInput!) {\n removeModelFromUseCase(input: $input)\n }\n ')
|
|
372
|
+
variables: Dict[str, object] = {'input': input}
|
|
373
|
+
response = self.execute(query=query, operation_name='RemoveModelFromUseCase', variables=variables, **kwargs)
|
|
374
|
+
data = self.get_data(response)
|
|
375
|
+
return RemoveModelFromUseCase.model_validate(data)
|
|
376
|
+
|
|
377
|
+
def deploy_model(self, input: DeployModelInput, **kwargs: Any) -> DeployModel:
|
|
378
|
+
query = gql('\n mutation DeployModel($input: DeployModelInput!) {\n deployModel(input: $input) {\n ...ModelServiceData\n }\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n isDefault\n desiredOnline\n createdAt\n }\n ')
|
|
379
|
+
variables: Dict[str, object] = {'input': input}
|
|
380
|
+
response = self.execute(query=query, operation_name='DeployModel', variables=variables, **kwargs)
|
|
381
|
+
data = self.get_data(response)
|
|
382
|
+
return DeployModel.model_validate(data)
|
|
383
|
+
|
|
368
384
|
def list_datasets(self, input: str, **kwargs: Any) -> ListDatasets:
|
|
369
385
|
query = gql('\n query ListDatasets($input: IdOrKey!) {\n datasets(useCase: $input) {\n ...DatasetData\n }\n }\n\n fragment DatasetData on Dataset {\n id\n key\n name\n createdAt\n kind\n records\n metricsUsage {\n feedbackCount\n comparisonCount\n metric {\n ...MetricData\n }\n }\n }\n\n fragment MetricData on Metric {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n hasDirectFeedbacks\n hasComparisonFeedbacks\n }\n ')
|
|
370
386
|
variables: Dict[str, object] = {'input': input}
|
|
@@ -380,14 +396,14 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
380
396
|
return DescribeDataset.model_validate(data)
|
|
381
397
|
|
|
382
398
|
def describe_use_case(self, input: str, **kwargs: Any) -> DescribeUseCase:
|
|
383
|
-
query = gql('\n query DescribeUseCase($input: IdOrKey!) {\n useCase(idOrKey: $input) {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n
|
|
399
|
+
query = gql('\n query DescribeUseCase($input: IdOrKey!) {\n useCase(idOrKey: $input) {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n isDefault\n desiredOnline\n createdAt\n }\n\n fragment UseCaseData on UseCase {\n id\n key\n name\n description\n createdAt\n metrics {\n ...MetricWithContextData\n }\n modelServices {\n ...ModelServiceData\n }\n permissions\n shares {\n team {\n id\n key\n name\n createdAt\n }\n role {\n id\n key\n name\n createdAt\n permissions\n }\n isOwner\n }\n }\n ')
|
|
384
400
|
variables: Dict[str, object] = {'input': input}
|
|
385
401
|
response = self.execute(query=query, operation_name='DescribeUseCase', variables=variables, **kwargs)
|
|
386
402
|
data = self.get_data(response)
|
|
387
403
|
return DescribeUseCase.model_validate(data)
|
|
388
404
|
|
|
389
405
|
def list_use_cases(self, **kwargs: Any) -> ListUseCases:
|
|
390
|
-
query = gql('\n query ListUseCases {\n useCases {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n
|
|
406
|
+
query = gql('\n query ListUseCases {\n useCases {\n ...UseCaseData\n }\n }\n\n fragment MetricWithContextData on MetricWithContext {\n id\n key\n name\n kind\n description\n scoringType\n createdAt\n }\n\n fragment ModelData on Model {\n id\n key\n name\n online\n error\n isExternal\n providerName\n isAdapter\n isTraining\n createdAt\n size\n computeConfig {\n tp\n kvCacheLen\n maxSeqLen\n }\n }\n\n fragment ModelServiceData on ModelService {\n id\n key\n name\n model {\n ...ModelData\n backbone {\n ...ModelData\n }\n }\n isDefault\n desiredOnline\n createdAt\n }\n\n fragment UseCaseData on UseCase {\n id\n key\n name\n description\n createdAt\n metrics {\n ...MetricWithContextData\n }\n modelServices {\n ...ModelServiceData\n }\n permissions\n shares {\n team {\n id\n key\n name\n createdAt\n }\n role {\n id\n key\n name\n createdAt\n permissions\n }\n isOwner\n }\n }\n ')
|
|
391
407
|
variables: Dict[str, object] = {}
|
|
392
408
|
response = self.execute(query=query, operation_name='ListUseCases', variables=variables, **kwargs)
|
|
393
409
|
data = self.get_data(response)
|
|
@@ -576,14 +592,14 @@ class GQLClient(BaseClientOpenTelemetry):
|
|
|
576
592
|
return ListGraders.model_validate(data)
|
|
577
593
|
|
|
578
594
|
def list_jobs(self, page: CursorPageInput, filter: Union[Optional[ListJobsFilterInput], UnsetType]=UNSET, order: Union[Optional[List[OrderPair]], UnsetType]=UNSET, **kwargs: Any) -> ListJobs:
|
|
579
|
-
query = gql('\n query ListJobs($page: CursorPageInput!, $filter: ListJobsFilterInput, $order: [OrderPair!]) {\n jobs(page: $page, filter: $filter, order: $order) {\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n nodes {\n ...JobData\n }\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
595
|
+
query = gql('\n query ListJobs($page: CursorPageInput!, $filter: ListJobsFilterInput, $order: [OrderPair!]) {\n jobs(page: $page, filter: $filter, order: $order) {\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n nodes {\n ...JobData\n }\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n status\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
580
596
|
variables: Dict[str, object] = {'page': page, 'filter': filter, 'order': order}
|
|
581
597
|
response = self.execute(query=query, operation_name='ListJobs', variables=variables, **kwargs)
|
|
582
598
|
data = self.get_data(response)
|
|
583
599
|
return ListJobs.model_validate(data)
|
|
584
600
|
|
|
585
601
|
def describe_job(self, id: Any, **kwargs: Any) -> DescribeJob:
|
|
586
|
-
query = gql('\n query DescribeJob($id: UUID!) {\n job(id: $id) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
602
|
+
query = gql('\n query DescribeJob($id: UUID!) {\n job(id: $id) {\n ...JobData\n }\n }\n\n fragment CustomRecipeData on CustomRecipe {\n id\n key\n name\n content\n contentHash\n editable\n global\n builtin\n inputSchema\n jsonSchema\n description\n labels {\n key\n value\n }\n createdAt\n updatedAt\n createdBy {\n id\n name\n email\n }\n }\n\n fragment JobData on Job {\n id\n name\n status\n createdAt\n createdBy {\n id\n name\n }\n startedAt\n endedAt\n durationMs\n progress\n error\n kind\n stages {\n name\n status\n info {\n __typename\n ... on TrainingJobStageOutput {\n monitoringLink\n totalNumSamples\n processedNumSamples\n checkpoints\n }\n ... on EvalJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n ... on BatchInferenceJobStageOutput {\n totalNumSamples\n processedNumSamples\n }\n }\n }\n useCase {\n id\n key\n name\n }\n recipe {\n ...CustomRecipeData\n }\n details {\n args\n recipeHash\n artifacts {\n id\n name\n kind\n status\n uri\n metadata\n createdAt\n byproducts {\n __typename\n ... on EvaluationByproducts {\n evalResults {\n mean\n min\n max\n stddev\n count\n sum\n feedbackCount\n jobId\n artifactId\n modelService {\n key\n name\n }\n metric {\n key\n name\n }\n }\n }\n }\n }\n }\n }\n ')
|
|
587
603
|
variables: Dict[str, object] = {'id': id}
|
|
588
604
|
response = self.execute(query=query, operation_name='DescribeJob', variables=variables, **kwargs)
|
|
589
605
|
data = self.get_data(response)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Any, Dict, List, Optional, Union
|
|
2
2
|
from .base_operation import GraphQLField
|
|
3
3
|
from .custom_typing_fields import AbcampaignGraphQLField, AbReportGraphQLField, AbVariantReportComparisonGraphQLField, AbVariantReportGraphQLField, ActivityGraphQLField, ActivityOutputGraphQLField, ApiKeyGraphQLField, ArtifactByproductsUnion, AuthProviderGraphQLField, BatchInferenceJobStageOutputGraphQLField, BillingUsageGraphQLField, ChatMessageGraphQLField, ComparisonFeedbackGraphQLField, CompletionConnectionGraphQLField, CompletionEdgeGraphQLField, CompletionGraphQLField, CompletionGroupDataConnectionGraphQLField, CompletionGroupDataEdgeGraphQLField, CompletionGroupDataGraphQLField, CompletionGroupFeedbackStatsGraphQLField, CompletionHistoryEntryOuputGraphQLField, CompletionLabelGraphQLField, CompletionMetadataGraphQLField, ComputePoolGraphQLField, ContractGraphQLField, CustomConfigOutputGraphQLField, CustomRecipeGraphQLField, CustomRecipeJobDetailsGraphQLField, DatasetByproductsGraphQLField, DatasetGraphQLField, DatasetMetricUsageGraphQLField, DatasetProgressGraphQLField, DatasetUploadProcessingStatusGraphQLField, DatasetValidationOutputGraphQLField, DeleteConfirmGraphQLField, DirectFeedbackGraphQLField, EmojiGraphQLField, EvalJobStageOutputGraphQLField, EvaluationByproductsGraphQLField, EvaluationResultGraphQLField, GlobalUsageGraphQLField, GpuAllocationGraphQLField, GraderConfigUnion, GraderGraphQLField, HarmonyGroupGraphQLField, InteractionOutputGraphQLField, IntervalGraphQLField, JobArtifactGraphQLField, JobConnectionGraphQLField, JobEdgeGraphQLField, JobGraphQLField, JobStageInfoOutputUnion, JobStageOutputGraphQLField, JudgeConfigOutputGraphQLField, JudgeExampleGraphQLField, JudgeGraphQLField, LabelGraphQLField, LabelKeyUsageGraphQLField, LabelUsageGraphQLField, LabelValueUsageGraphQLField, MetaObjectGraphQLField, MetricActivityGraphQLField, MetricGraphQLField, MetricWithContextGraphQLField, ModelByproductsGraphQLField, ModelComputeConfigOutputGraphQLField, ModelGraphQLField, ModelPlacementOutputGraphQLField, ModelServiceGraphQLField, PageInfoGraphQLField, PrebuiltConfigDefinitionGraphQLField, PrebuiltConfigOutputGraphQLField, PrebuiltCriteriaGraphQLField, ProviderListGraphQLField, RemoteConfigOutputGraphQLField, RemoteEnvGraphQLField, RemoteEnvTestOfflineGraphQLField, RemoteEnvTestOnlineGraphQLField, RoleGraphQLField, SearchResultGraphQLField, SessionGraphQLField, SettingsGraphQLField, ShareGraphQLField, SystemPromptTemplateGraphQLField, TeamGraphQLField, TeamMemberGraphQLField, TeamWithroleGraphQLField, TimeseriesGraphQLField, ToolProviderGraphQLField, TrainingJobStageOutputGraphQLField, TrendResultGraphQLField, UnitConfigGraphQLField, UsageAggregateItemGraphQLField, UsageAggregatePerUseCaseItemGraphQLField, UsageGraphQLField, UsageStatsByModelGraphQLField, UsageStatsGraphQLField, UseCaseGraphQLField, UseCaseItemGraphQLField, UseCaseMetadataGraphQLField, UserGraphQLField, WidgetGraphQLField
|
|
4
|
-
from .input_types import AbCampaignFilter, ArtifactFilter, CursorPageInput, FeedbackFilterInput, ListCompletionsFilterInput, MetricTrendInput, ModelServiceFilter, OrderPair, TimeRange, TimeseriesInput, UseCaseFilter
|
|
4
|
+
from .input_types import AbCampaignFilter, ArtifactFilter, CursorPageInput, FeedbackFilterInput, ListCompletionsFilterInput, MetricTrendInput, ModelFilter, ModelServiceFilter, OrderPair, TimeRange, TimeseriesInput, UseCaseFilter
|
|
5
5
|
|
|
6
6
|
class AbReportFields(GraphQLField):
|
|
7
7
|
"""@private"""
|
|
@@ -629,6 +629,7 @@ class CustomRecipeJobDetailsFields(GraphQLField):
|
|
|
629
629
|
return JobArtifactFields('artifacts')
|
|
630
630
|
num_gpus: 'CustomRecipeJobDetailsGraphQLField' = CustomRecipeJobDetailsGraphQLField('numGpus')
|
|
631
631
|
gpu_duration_ms: 'CustomRecipeJobDetailsGraphQLField' = CustomRecipeJobDetailsGraphQLField('gpuDurationMs')
|
|
632
|
+
compute_pool_id: 'CustomRecipeJobDetailsGraphQLField' = CustomRecipeJobDetailsGraphQLField('computePoolId')
|
|
632
633
|
|
|
633
634
|
def fields(self, *subfields: Union[CustomRecipeJobDetailsGraphQLField, 'JobArtifactFields']) -> 'CustomRecipeJobDetailsFields':
|
|
634
635
|
"""Subfields should come from the CustomRecipeJobDetailsFields class"""
|
|
@@ -658,8 +659,13 @@ class DatasetFields(GraphQLField):
|
|
|
658
659
|
@classmethod
|
|
659
660
|
def progress(cls) -> 'DatasetProgressFields':
|
|
660
661
|
return DatasetProgressFields('progress')
|
|
662
|
+
download_url: 'DatasetGraphQLField' = DatasetGraphQLField('downloadUrl')
|
|
661
663
|
|
|
662
|
-
|
|
664
|
+
@classmethod
|
|
665
|
+
def use_case(cls) -> 'UseCaseFields':
|
|
666
|
+
return UseCaseFields('useCase')
|
|
667
|
+
|
|
668
|
+
def fields(self, *subfields: Union[DatasetGraphQLField, 'DatasetMetricUsageFields', 'DatasetProgressFields', 'UseCaseFields']) -> 'DatasetFields':
|
|
663
669
|
"""Subfields should come from the DatasetFields class"""
|
|
664
670
|
self._subfields.extend(subfields)
|
|
665
671
|
return self
|
|
@@ -1447,6 +1453,8 @@ class ModelFields(GraphQLField):
|
|
|
1447
1453
|
return ModelFields('parent')
|
|
1448
1454
|
is_training: 'ModelGraphQLField' = ModelGraphQLField('isTraining')
|
|
1449
1455
|
'indicates if a training is pending or running for this model'
|
|
1456
|
+
is_published: 'ModelGraphQLField' = ModelGraphQLField('isPublished')
|
|
1457
|
+
is_stable: 'ModelGraphQLField' = ModelGraphQLField('isStable')
|
|
1450
1458
|
capabilities: 'ModelGraphQLField' = ModelGraphQLField('capabilities')
|
|
1451
1459
|
supported_tp: 'ModelGraphQLField' = ModelGraphQLField('supportedTp')
|
|
1452
1460
|
family: 'ModelGraphQLField' = ModelGraphQLField('family')
|
|
@@ -1524,9 +1532,10 @@ class ModelServiceFields(GraphQLField):
|
|
|
1524
1532
|
@classmethod
|
|
1525
1533
|
def model(cls) -> 'ModelFields':
|
|
1526
1534
|
return ModelFields('model')
|
|
1527
|
-
attached: 'ModelServiceGraphQLField' = ModelServiceGraphQLField('attached')
|
|
1528
1535
|
is_default: 'ModelServiceGraphQLField' = ModelServiceGraphQLField('isDefault')
|
|
1529
1536
|
desired_online: 'ModelServiceGraphQLField' = ModelServiceGraphQLField('desiredOnline')
|
|
1537
|
+
deleted: 'ModelServiceGraphQLField' = ModelServiceGraphQLField('deleted')
|
|
1538
|
+
'Whether or not this model service has been deleted.'
|
|
1530
1539
|
|
|
1531
1540
|
@classmethod
|
|
1532
1541
|
def activity(cls, *, timerange: Optional[TimeRange]=None) -> 'ActivityFields':
|
|
@@ -2092,6 +2101,13 @@ class UseCaseFields(GraphQLField):
|
|
|
2092
2101
|
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
2093
2102
|
return ModelServiceFields('modelService', arguments=cleared_arguments)
|
|
2094
2103
|
|
|
2104
|
+
@classmethod
|
|
2105
|
+
def models(cls, *, filter: Optional[ModelFilter]=None) -> 'ModelFields':
|
|
2106
|
+
"""Returns models associated with this use case."""
|
|
2107
|
+
arguments: Dict[str, Dict[str, Any]] = {'filter': {'type': 'ModelFilter', 'value': filter}}
|
|
2108
|
+
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
2109
|
+
return ModelFields('models', arguments=cleared_arguments)
|
|
2110
|
+
|
|
2095
2111
|
@classmethod
|
|
2096
2112
|
def default_model_service(cls) -> 'ModelServiceFields':
|
|
2097
2113
|
return ModelServiceFields('defaultModelService')
|
|
@@ -2143,7 +2159,7 @@ class UseCaseFields(GraphQLField):
|
|
|
2143
2159
|
def tool_providers(cls) -> 'ToolProviderFields':
|
|
2144
2160
|
return ToolProviderFields('toolProviders')
|
|
2145
2161
|
|
|
2146
|
-
def fields(self, *subfields: Union[UseCaseGraphQLField, 'AbcampaignFields', 'ActivityFields', 'LabelUsageFields', 'MetricWithContextFields', 'ModelServiceFields', 'SettingsFields', 'ShareFields', 'ToolProviderFields', 'UseCaseMetadataFields', 'WidgetFields']) -> 'UseCaseFields':
|
|
2162
|
+
def fields(self, *subfields: Union[UseCaseGraphQLField, 'AbcampaignFields', 'ActivityFields', 'LabelUsageFields', 'MetricWithContextFields', 'ModelFields', 'ModelServiceFields', 'SettingsFields', 'ShareFields', 'ToolProviderFields', 'UseCaseMetadataFields', 'WidgetFields']) -> 'UseCaseFields':
|
|
2147
2163
|
"""Subfields should come from the UseCaseFields class"""
|
|
2148
2164
|
self._subfields.extend(subfields)
|
|
2149
2165
|
return self
|
|
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
|
|
|
2
2
|
from .base_model import Upload
|
|
3
3
|
from .custom_fields import AbcampaignFields, ApiKeyFields, CompletionFields, CustomRecipeFields, DatasetFields, DatasetUploadProcessingStatusFields, DatasetValidationOutputFields, DeleteConfirmFields, DirectFeedbackFields, GraderFields, JobFields, JudgeFields, MetricFields, MetricWithContextFields, ModelFields, ModelServiceFields, RemoteEnvFields, RoleFields, SystemPromptTemplateFields, TeamFields, TeamMemberFields, ToolProviderFields, UseCaseFields, UserFields
|
|
4
4
|
from .custom_typing_fields import GraphQLField, RemoteEnvTestUnion
|
|
5
|
-
from .input_types import AbcampaignCreate, AddExternalModelInput, AddHFModelInput, AddModelInput,
|
|
5
|
+
from .input_types import AbcampaignCreate, AddExternalModelInput, AddHFModelInput, AddModelInput, AddModelToUseCaseInput, ApiKeyCreate, CancelAllocationInput, CreateRecipeInput, CreateToolProviderInput, DatasetCreate, DatasetCreateFromFilters, DatasetCreateFromMultipartUpload, DeleteModelInput, DeployModelInput, FeedbackAddInput, FeedbackUpdateInput, GraderCreateInput, GraderUpdateInput, JobInput, JudgeCreate, JudgeUpdate, MetricCreate, MetricLink, MetricUnlink, ModelComputeConfigInput, PrebuiltJudgeCreate, RemoteEnvCreate, RemoveModelFromUseCaseInput, ResizePartitionInput, RoleCreate, SystemPromptTemplateCreate, SystemPromptTemplateUpdate, TeamCreate, TeamMemberRemove, TeamMemberSet, UpdateCompletion, UpdateModelInput, UpdateModelService, UpdateRecipeInput, UpdateToolProviderInput, UseCaseCreate, UseCaseShares, UseCaseUpdate, UserCreate
|
|
6
6
|
|
|
7
7
|
class Mutation:
|
|
8
8
|
"""@private"""
|
|
@@ -123,16 +123,16 @@ class Mutation:
|
|
|
123
123
|
return GraphQLField(field_name='unlinkMetric', arguments=cleared_arguments)
|
|
124
124
|
|
|
125
125
|
@classmethod
|
|
126
|
-
def
|
|
127
|
-
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': '
|
|
126
|
+
def deploy_model(cls, input: DeployModelInput) -> ModelServiceFields:
|
|
127
|
+
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': 'DeployModelInput!', 'value': input}}
|
|
128
128
|
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
129
|
-
return ModelServiceFields(field_name='
|
|
129
|
+
return ModelServiceFields(field_name='deployModel', arguments=cleared_arguments)
|
|
130
130
|
|
|
131
131
|
@classmethod
|
|
132
|
-
def
|
|
133
|
-
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': '
|
|
132
|
+
def update_model(cls, input: UpdateModelInput) -> ModelFields:
|
|
133
|
+
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': 'UpdateModelInput!', 'value': input}}
|
|
134
134
|
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
135
|
-
return
|
|
135
|
+
return ModelFields(field_name='updateModel', arguments=cleared_arguments)
|
|
136
136
|
|
|
137
137
|
@classmethod
|
|
138
138
|
def update_model_service(cls, input: UpdateModelService) -> ModelServiceFields:
|
|
@@ -142,7 +142,7 @@ class Mutation:
|
|
|
142
142
|
|
|
143
143
|
@classmethod
|
|
144
144
|
def terminate_model(cls, id_or_key: str, force: bool) -> GraphQLField:
|
|
145
|
-
"""If a model is used by several use cases with `
|
|
145
|
+
"""If a model is used by several use cases with `desiredOnline = true`, you need to specify 'force = true' to be able to deactivate the model"""
|
|
146
146
|
arguments: Dict[str, Dict[str, Any]] = {'idOrKey': {'type': 'IdOrKey!', 'value': id_or_key}, 'force': {'type': 'Boolean!', 'value': force}}
|
|
147
147
|
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
148
148
|
return GraphQLField(field_name='terminateModel', arguments=cleared_arguments)
|
|
@@ -171,6 +171,27 @@ class Mutation:
|
|
|
171
171
|
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
172
172
|
return ModelFields(field_name='updateModelComputeConfig', arguments=cleared_arguments)
|
|
173
173
|
|
|
174
|
+
@classmethod
|
|
175
|
+
def add_model_to_use_case(cls, input: AddModelToUseCaseInput) -> GraphQLField:
|
|
176
|
+
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': 'AddModelToUseCaseInput!', 'value': input}}
|
|
177
|
+
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
178
|
+
return GraphQLField(field_name='addModelToUseCase', arguments=cleared_arguments)
|
|
179
|
+
|
|
180
|
+
@classmethod
|
|
181
|
+
def remove_model_from_use_case(cls, input: RemoveModelFromUseCaseInput) -> GraphQLField:
|
|
182
|
+
"""Removes a model from a use case. If the model is not bound to any other use case or published organisation wide, it is deleted from storage."""
|
|
183
|
+
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': 'RemoveModelFromUseCaseInput!', 'value': input}}
|
|
184
|
+
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
185
|
+
return GraphQLField(field_name='removeModelFromUseCase', arguments=cleared_arguments)
|
|
186
|
+
|
|
187
|
+
@classmethod
|
|
188
|
+
def delete_model(cls, input: DeleteModelInput) -> GraphQLField:
|
|
189
|
+
"""Deletes a model: removes from all use cases, unpublishes from org registry,
|
|
190
|
+
and deletes from storage."""
|
|
191
|
+
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': 'DeleteModelInput!', 'value': input}}
|
|
192
|
+
cleared_arguments = {key: value for key, value in arguments.items() if value['value'] is not None}
|
|
193
|
+
return GraphQLField(field_name='deleteModel', arguments=cleared_arguments)
|
|
194
|
+
|
|
174
195
|
@classmethod
|
|
175
196
|
def add_remote_env(cls, input: RemoteEnvCreate) -> RemoteEnvFields:
|
|
176
197
|
arguments: Dict[str, Dict[str, Any]] = {'input': {'type': 'RemoteEnvCreate!', 'value': input}}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from pydantic import Field
|
|
2
|
+
from .base_model import BaseModel
|
|
3
|
+
from .fragments import ModelServiceData
|
|
4
|
+
|
|
5
|
+
class DeployModel(BaseModel):
|
|
6
|
+
"""@public"""
|
|
7
|
+
deploy_model: 'DeployModelDeployModel' = Field(alias='deployModel')
|
|
8
|
+
|
|
9
|
+
class DeployModelDeployModel(ModelServiceData):
|
|
10
|
+
"""@public"""
|
|
11
|
+
pass
|
|
12
|
+
DeployModel.model_rebuild()
|
|
@@ -67,6 +67,7 @@ class DateBucketUnit(str, Enum):
|
|
|
67
67
|
class ExternalModelProviderName(str, Enum):
|
|
68
68
|
"""@public"""
|
|
69
69
|
OPEN_AI = 'OPEN_AI'
|
|
70
|
+
LEGACY_OPEN_AI = 'LEGACY_OPEN_AI'
|
|
70
71
|
GOOGLE = 'GOOGLE'
|
|
71
72
|
ANTHROPIC = 'ANTHROPIC'
|
|
72
73
|
|
|
@@ -161,33 +162,14 @@ class ModelOnline(str, Enum):
|
|
|
161
162
|
OFFLINE = 'OFFLINE'
|
|
162
163
|
ERROR = 'ERROR'
|
|
163
164
|
|
|
164
|
-
class
|
|
165
|
+
class ModelServiceStatus(str, Enum):
|
|
165
166
|
"""@public"""
|
|
166
167
|
PENDING = 'PENDING'
|
|
167
168
|
ONLINE = 'ONLINE'
|
|
168
169
|
OFFLINE = 'OFFLINE'
|
|
169
|
-
DETACHED = 'DETACHED'
|
|
170
170
|
TURNED_OFF = 'TURNED_OFF'
|
|
171
171
|
ERROR = 'ERROR'
|
|
172
172
|
|
|
173
|
-
class OpenAIModel(str, Enum):
|
|
174
|
-
"""@public"""
|
|
175
|
-
GPT41 = 'GPT41'
|
|
176
|
-
GPT4O = 'GPT4O'
|
|
177
|
-
GPT4O_MINI = 'GPT4O_MINI'
|
|
178
|
-
O1 = 'O1'
|
|
179
|
-
O1_MINI = 'O1_MINI'
|
|
180
|
-
O3_MINI = 'O3_MINI'
|
|
181
|
-
O4_MINI = 'O4_MINI'
|
|
182
|
-
O3 = 'O3'
|
|
183
|
-
GPT4 = 'GPT4'
|
|
184
|
-
GPT4_TURBO = 'GPT4_TURBO'
|
|
185
|
-
GPT3_5_TURBO = 'GPT3_5_TURBO'
|
|
186
|
-
GPT5 = 'GPT5'
|
|
187
|
-
GPT5_MINI = 'GPT5_MINI'
|
|
188
|
-
GPT5_NANO = 'GPT5_NANO'
|
|
189
|
-
GPT41_MINI = 'GPT41_MINI'
|
|
190
|
-
|
|
191
173
|
class PrebuiltCriteriaKey(str, Enum):
|
|
192
174
|
"""@public"""
|
|
193
175
|
FAITHFULNESS = 'FAITHFULNESS'
|
|
@@ -201,6 +183,7 @@ class Protocol(str, Enum):
|
|
|
201
183
|
class ProviderName(str, Enum):
|
|
202
184
|
"""@public"""
|
|
203
185
|
OPEN_AI = 'OPEN_AI'
|
|
186
|
+
LEGACY_OPEN_AI = 'LEGACY_OPEN_AI'
|
|
204
187
|
HARMONY = 'HARMONY'
|
|
205
188
|
GOOGLE = 'GOOGLE'
|
|
206
189
|
ANTHROPIC = 'ANTHROPIC'
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Annotated, Any, List, Literal, Optional, Union
|
|
2
2
|
from pydantic import Field
|
|
3
3
|
from .base_model import BaseModel
|
|
4
|
-
from .enums import AbcampaignStatus, CompletionSource, DatasetKind, FeedbackType, GraderTypeEnum, HarmonyStatus, JobArtifactKind, JobKind, JobStatus, JobStatusOutput, JudgeCapability, MetricKind, MetricScoringType, ModelOnline, PrebuiltCriteriaKey, ProviderName, RemoteEnvStatus
|
|
4
|
+
from .enums import AbcampaignStatus, CompletionSource, DatasetKind, FeedbackType, GraderTypeEnum, HarmonyStatus, JobArtifactKind, JobArtifactStatus, JobKind, JobStatus, JobStatusOutput, JudgeCapability, MetricKind, MetricScoringType, ModelOnline, PrebuiltCriteriaKey, ProviderName, RemoteEnvStatus
|
|
5
5
|
|
|
6
6
|
class AbCampaignCreateData(BaseModel):
|
|
7
7
|
"""@public"""
|
|
@@ -246,7 +246,7 @@ class ModelData(BaseModel):
|
|
|
246
246
|
is_adapter: bool = Field(alias='isAdapter')
|
|
247
247
|
is_training: bool = Field(alias='isTraining')
|
|
248
248
|
created_at: int = Field(alias='createdAt')
|
|
249
|
-
size: Optional[
|
|
249
|
+
size: Optional[int]
|
|
250
250
|
compute_config: Optional['ModelDataComputeConfig'] = Field(alias='computeConfig')
|
|
251
251
|
|
|
252
252
|
class ModelDataComputeConfig(BaseModel):
|
|
@@ -431,6 +431,7 @@ class JobDataDetailsArtifacts(BaseModel):
|
|
|
431
431
|
id: Any
|
|
432
432
|
name: str
|
|
433
433
|
kind: JobArtifactKind
|
|
434
|
+
status: JobArtifactStatus
|
|
434
435
|
uri: Optional[str]
|
|
435
436
|
metadata: Any
|
|
436
437
|
created_at: int = Field(alias='createdAt')
|
|
@@ -584,7 +585,7 @@ class ModelDataAdmin(BaseModel):
|
|
|
584
585
|
is_adapter: bool = Field(alias='isAdapter')
|
|
585
586
|
is_training: bool = Field(alias='isTraining')
|
|
586
587
|
created_at: int = Field(alias='createdAt')
|
|
587
|
-
size: Optional[
|
|
588
|
+
size: Optional[int]
|
|
588
589
|
|
|
589
590
|
class ModelDataAdminUseCases(BaseModel):
|
|
590
591
|
"""@public"""
|
|
@@ -598,7 +599,6 @@ class ModelServiceData(BaseModel):
|
|
|
598
599
|
key: str
|
|
599
600
|
name: str
|
|
600
601
|
model: 'ModelServiceDataModel'
|
|
601
|
-
attached: bool
|
|
602
602
|
is_default: bool = Field(alias='isDefault')
|
|
603
603
|
desired_online: bool = Field(alias='desiredOnline')
|
|
604
604
|
created_at: int = Field(alias='createdAt')
|