clarifai 9.8.2__py3-none-any.whl → 9.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +17 -0
- clarifai/client/search.py +173 -0
- clarifai/client/workflow.py +1 -1
- clarifai/constants/search.py +2 -0
- clarifai/models/model_serving/README.md +3 -3
- clarifai/models/model_serving/docs/dependencies.md +5 -10
- clarifai/models/model_serving/examples/image_classification/age_vit/requirements.txt +1 -0
- clarifai/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt +1 -0
- clarifai/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt +1 -0
- clarifai/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt +1 -0
- clarifai/models/model_serving/examples/visual_detection/yolov5x/requirements.txt +1 -1
- clarifai/models/model_serving/examples/visual_embedding/vit-base/requirements.txt +1 -0
- clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt +1 -0
- clarifai/models/model_serving/pb_model_repository.py +1 -3
- clarifai/schema/search.py +60 -0
- clarifai/versions.py +1 -1
- clarifai/workflows/export.py +9 -8
- clarifai/workflows/utils.py +1 -1
- clarifai/workflows/validate.py +1 -1
- {clarifai-9.8.2.dist-info → clarifai-9.9.0.dist-info}/METADATA +1 -1
- {clarifai-9.8.2.dist-info → clarifai-9.9.0.dist-info}/RECORD +44 -56
- clarifai_utils/client/app.py +17 -0
- clarifai_utils/client/search.py +173 -0
- clarifai_utils/client/workflow.py +1 -1
- clarifai_utils/constants/search.py +2 -0
- clarifai_utils/models/model_serving/README.md +3 -3
- clarifai_utils/models/model_serving/docs/dependencies.md +5 -10
- clarifai_utils/models/model_serving/examples/image_classification/age_vit/requirements.txt +1 -0
- clarifai_utils/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt +1 -0
- clarifai_utils/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt +1 -0
- clarifai_utils/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt +1 -0
- clarifai_utils/models/model_serving/examples/visual_detection/yolov5x/requirements.txt +1 -1
- clarifai_utils/models/model_serving/examples/visual_embedding/vit-base/requirements.txt +1 -0
- clarifai_utils/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt +1 -0
- clarifai_utils/models/model_serving/pb_model_repository.py +1 -3
- clarifai_utils/schema/search.py +60 -0
- clarifai_utils/versions.py +1 -1
- clarifai_utils/workflows/export.py +9 -8
- clarifai_utils/workflows/utils.py +1 -1
- clarifai_utils/workflows/validate.py +1 -1
- clarifai/models/model_serving/envs/triton_conda-cp3.8-torch1.13.1-19f97078.yaml +0 -35
- clarifai/models/model_serving/envs/triton_conda-cp3.8-torch2.0.0-ce980f28.yaml +0 -51
- clarifai/models/model_serving/examples/image_classification/age_vit/triton_conda.yaml +0 -1
- clarifai/models/model_serving/examples/text_classification/xlm-roberta/triton_conda.yaml +0 -1
- clarifai/models/model_serving/examples/text_to_image/sd-v1.5/triton_conda.yaml +0 -1
- clarifai/models/model_serving/examples/text_to_text/bart-summarize/triton_conda.yaml +0 -1
- clarifai/models/model_serving/examples/visual_detection/yolov5x/triton_conda.yaml +0 -1
- clarifai/models/model_serving/examples/visual_embedding/vit-base/triton_conda.yaml +0 -1
- clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/envs/triton_conda-cp3.8-torch1.13.1-19f97078.yaml +0 -35
- clarifai_utils/models/model_serving/envs/triton_conda-cp3.8-torch2.0.0-ce980f28.yaml +0 -51
- clarifai_utils/models/model_serving/examples/image_classification/age_vit/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/examples/text_classification/xlm-roberta/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/examples/text_to_image/sd-v1.5/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/examples/text_to_text/bart-summarize/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/examples/visual_detection/yolov5x/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/examples/visual_embedding/vit-base/triton_conda.yaml +0 -1
- clarifai_utils/models/model_serving/examples/visual_segmentation/segformer-b2/triton_conda.yaml +0 -1
- {clarifai-9.8.2.dist-info → clarifai-9.9.0.dist-info}/LICENSE +0 -0
- {clarifai-9.8.2.dist-info → clarifai-9.9.0.dist-info}/WHEEL +0 -0
- {clarifai-9.8.2.dist-info → clarifai-9.9.0.dist-info}/entry_points.txt +0 -0
- {clarifai-9.8.2.dist-info → clarifai-9.9.0.dist-info}/top_level.txt +0 -0
clarifai/client/app.py
CHANGED
|
@@ -13,6 +13,7 @@ from clarifai.client.input import Inputs
|
|
|
13
13
|
from clarifai.client.lister import Lister
|
|
14
14
|
from clarifai.client.model import Model
|
|
15
15
|
from clarifai.client.module import Module
|
|
16
|
+
from clarifai.client.search import Search
|
|
16
17
|
from clarifai.client.workflow import Workflow
|
|
17
18
|
from clarifai.errors import UserError
|
|
18
19
|
from clarifai.urls.helper import ClarifaiUrlHelper
|
|
@@ -555,6 +556,22 @@ class App(Lister, BaseClient):
|
|
|
555
556
|
raise Exception(response.status)
|
|
556
557
|
self.logger.info("\nModule Deleted\n%s", response.status)
|
|
557
558
|
|
|
559
|
+
def search(self, **kwargs) -> Search:
  """Returns a Search object for the user and app ID.

  Args:
      **kwargs: Forwarded to the Search constructor; see the Search class in
          clarifai.client.search for the accepted keys (e.g. top_k, metric).
          `user_id` and `app_id` may be passed explicitly to override the
          defaults taken from this client.

  Returns:
      Search: A Search object for the user and app ID.

  Example:
      >>> from clarifai.client.app import App
      >>> app = App(app_id="app_id", user_id="user_id")
      >>> search_client = app.search(top_k=12, metric="euclidean")
  """
  # Search.__init__ requires user_id and app_id. Default them from this client so that
  # the documented call above (no IDs passed) works instead of raising TypeError.
  # NOTE(review): assumes self.user_app_id (as used by the other BaseClient subclasses)
  # exposes `user_id` and `app_id` -- confirm against BaseClient.
  kwargs.setdefault("user_id", self.user_app_id.user_id)
  kwargs.setdefault("app_id", self.user_app_id.app_id)
  return Search(**kwargs)
|
|
574
|
+
|
|
558
575
|
def __getattr__(self, name):
|
|
559
576
|
return getattr(self.app_info, name)
|
|
560
577
|
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from typing import Any, Callable, Dict, Generator
|
|
2
|
+
|
|
3
|
+
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
|
4
|
+
from clarifai_grpc.grpc.api.status import status_code_pb2
|
|
5
|
+
from google.protobuf.json_format import MessageToDict
|
|
6
|
+
from google.protobuf.struct_pb2 import Struct
|
|
7
|
+
from schema import SchemaError
|
|
8
|
+
|
|
9
|
+
from clarifai.client.base import BaseClient
|
|
10
|
+
from clarifai.client.input import Inputs
|
|
11
|
+
from clarifai.client.lister import Lister
|
|
12
|
+
from clarifai.constants.search import DEFAULT_SEARCH_METRIC, DEFAULT_TOP_K
|
|
13
|
+
from clarifai.errors import UserError
|
|
14
|
+
from clarifai.schema.search import get_schema
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Search(Lister, BaseClient):
  """Client that builds rank/filter queries and pages through annotation search results."""

  def __init__(self,
               user_id,
               app_id,
               top_k: int = DEFAULT_TOP_K,
               metric: str = DEFAULT_SEARCH_METRIC):
    """Initialize the Search object.

    Args:
        user_id (str): User ID.
        app_id (str): App ID.
        top_k (int, optional): Passed to Lister as the page size. Defaults to DEFAULT_TOP_K.
        metric (str, optional): Similarity metric (either 'cosine' or 'euclidean').
            Defaults to DEFAULT_SEARCH_METRIC.

    Raises:
        UserError: If the metric is not 'cosine' or 'euclidean'.
    """
    if metric not in ["cosine", "euclidean"]:
      raise UserError("Metric should be either cosine or euclidean")

    self.user_id = user_id
    self.app_id = app_id
    # Map the user-facing metric name to the value sent in the Search proto.
    self.metric_distance = dict(cosine="COSINE_DISTANCE", euclidean="EUCLIDEAN_DISTANCE")[metric]
    self.data_proto = resources_pb2.Data()

    self.inputs = Inputs(user_id=self.user_id, app_id=self.app_id)
    self.rank_filter_schema = get_schema()
    BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id)
    Lister.__init__(self, page_size=top_k)

  def _get_annot_proto(self, **kwargs):
    """Get an Annotation proto message based on keyword arguments.

    Args:
        **kwargs: One rank or filter item. Supported keys are 'image_bytes',
            'image_url', 'concepts', 'text_raw', 'metadata' and 'geo_point'.

    Returns:
        resources_pb2.Annotation: An Annotation proto message. An empty
        Annotation is returned when no keyword arguments are given.

    Raises:
        UserError: If kwargs contains a key that is not supported.
    """
    if not kwargs:
      return resources_pb2.Annotation()

    # A fresh Data proto per call; it is also kept on the instance as `data_proto`.
    self.data_proto = resources_pb2.Data()
    for key, value in kwargs.items():
      if key == "image_bytes":
        image_proto = self.inputs.get_input_from_bytes("", image_bytes=value).data.image
        self.data_proto.image.CopyFrom(image_proto)

      elif key == "image_url":
        image_proto = self.inputs.get_input_from_url("", image_url=value).data.image
        self.data_proto.image.CopyFrom(image_proto)

      elif key == "concepts":
        for concept in value:
          concept_proto = resources_pb2.Concept(**concept)
          self.data_proto.concepts.add().CopyFrom(concept_proto)

      elif key == "text_raw":
        text_proto = self.inputs.get_input_from_bytes(
            "", text_bytes=bytes(value, 'utf-8')).data.text
        self.data_proto.text.CopyFrom(text_proto)

      elif key == "metadata":
        metadata_struct = Struct()
        metadata_struct.update(value)
        self.data_proto.metadata.CopyFrom(metadata_struct)

      elif key == "geo_point":
        geo_point_proto = self._get_geo_point_proto(value["longitude"], value["latitude"],
                                                    value["geo_limit"])
        self.data_proto.geo.CopyFrom(geo_point_proto)

      else:
        raise UserError(f"kwargs contain key that is not supported: {key}")
    return resources_pb2.Annotation(data=self.data_proto)

  def _get_geo_point_proto(self, longitude: float, latitude: float,
                           geo_limit: float) -> resources_pb2.Geo:
    """Get a Geo proto message based on geographical data.

    Args:
        longitude (float): Longitude coordinate.
        latitude (float): Latitude coordinate.
        geo_limit (float): Geographical limit, sent with type "withinKilometers".

    Returns:
        resources_pb2.Geo: A Geo proto message.
    """
    return resources_pb2.Geo(
        geo_point=resources_pb2.GeoPoint(longitude=longitude, latitude=latitude),
        geo_limit=resources_pb2.GeoLimit(type="withinKilometers", value=geo_limit))

  def list_all_pages_generator(
      self, endpoint: Callable[..., Any], proto_message: Any,
      request_data: Dict[str, Any]) -> Generator[Any, None, None]:
    """Lists all pages of a resource.

    Args:
        endpoint (Callable): The endpoint to call.
        proto_message (Any): The proto message class used to build each request.
        request_data (dict): The request data to use. A 'pagination' entry is
            added to this dict in place.

    Yields:
        The raw response message for each page that contains hits.

    Raises:
        Exception: If a page request does not return a SUCCESS status.
    """
    page = 1
    request_data['pagination'] = service_pb2.Pagination(page=page, per_page=self.default_page_size)
    while True:
      request_data['pagination'].page = page
      response = self._grpc_request(endpoint, proto_message(**request_data))
      # Check the status first; a failed response is never converted or inspected.
      if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(f"Listing failed with response {response!r}")

      dict_response = MessageToDict(response, preserving_proto_field_name=True)
      # A page whose dict form has no 'hits' key ends the listing.
      if 'hits' not in dict_response:
        break
      page += 1
      yield response

  def query(self, ranks=None, filters=None):
    """Perform a query with rank and filters.

    Args:
        ranks (List[Dict], optional): List of rank parameters. Defaults to [{}].
        filters (List[Dict], optional): List of filter parameters. Defaults to [{}].

    Returns:
        Generator: A generator of per-page search responses.

    Raises:
        UserError: If ranks or filters do not match the rank/filter schema.
    """
    # Use None sentinels so no mutable default list is shared between calls.
    ranks = [{}] if ranks is None else ranks
    filters = [{}] if filters is None else filters

    try:
      self.rank_filter_schema.validate(ranks)
      self.rank_filter_schema.validate(filters)
    except SchemaError as err:
      # Chain the schema failure so the original traceback stays available.
      raise UserError(f"Invalid rank or filter input: {err}") from err

    all_ranks = [
        resources_pb2.Rank(annotation=self._get_annot_proto(**rank_dict)) for rank_dict in ranks
    ]
    all_filters = [
        resources_pb2.Filter(annotation=self._get_annot_proto(**filter_dict))
        for filter_dict in filters
    ]

    request_data = dict(
        user_app_id=self.user_app_id,
        searches=[
            resources_pb2.Search(
                query=resources_pb2.Query(ranks=all_ranks, filters=all_filters),
                metric=self.metric_distance)
        ])

    return self.list_all_pages_generator(self.STUB.PostAnnotationsSearches,
                                         service_pb2.PostAnnotationsSearchesRequest, request_data)
|
clarifai/client/workflow.py
CHANGED
|
@@ -194,7 +194,7 @@ class Workflow(Lister, BaseClient):
|
|
|
194
194
|
Example:
|
|
195
195
|
>>> from clarifai.client.workflow import Workflow
|
|
196
196
|
>>> workflow = Workflow("https://clarifai.com/clarifai/main/workflows/Demographics")
|
|
197
|
-
>>> workflow.export('out_path')
|
|
197
|
+
>>> workflow.export('out_path.yml')
|
|
198
198
|
"""
|
|
199
199
|
request = service_pb2.GetWorkflowRequest(user_app_id=self.user_app_id, workflow_id=self.id)
|
|
200
200
|
response = self._grpc_request(self.STUB.GetWorkflow, request)
|
|
@@ -16,10 +16,10 @@ $ clarifai-model-upload-init --model_name <Your model name> \
|
|
|
16
16
|
3. Add your model loading and inference code inside the `inference.py` script of the generated model repository under the `setup()` and `predict()` functions respectively. Refer to the [Inference Script section]() for a description of this file.
|
|
17
17
|
4. Test your implementation locally by running `<your_triton_folder>/1/test.py` with basic predefined tests.
|
|
18
18
|
To avoid missing dependencies when deploying, we recommend using conda to create a clean environment from the [Clarifai base envs](./envs/), then installing everything in `requirements.txt`. Follow the instructions inside [test.py](./models/test.py) to implement custom tests.
|
|
19
|
-
* Create conda env
|
|
19
|
+
* Create conda env and install requirements:
|
|
20
20
|
```bash
|
|
21
|
-
# create env
|
|
22
|
-
conda
|
|
21
|
+
# create env (note: only python version 3.8 is supported currently)
|
|
22
|
+
conda create -n <your_env> python=3.8
|
|
23
23
|
# activate it
|
|
24
24
|
conda activate <your_env>
|
|
25
25
|
# install dependencies
|
|
@@ -3,14 +3,9 @@
|
|
|
3
3
|
Each model built for inference with triton requires certain dependencies & dependency versions be installed for successful inference execution.
|
|
4
4
|
An execution environment is created for each model to be deployed on Clarifai and all necessary dependencies as listed in the `requirements.txt` file are installed there.
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
## Supported python and torch versions
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
```
|
|
13
|
-
All dependencies in this environment can be [found here](../envs/triton_conda.yaml).
|
|
14
|
-
|
|
15
|
-
By default all `triton_conda.yaml` files in the generated model repository use the environment above as its currently the only one available.
|
|
16
|
-
Dependencies specified in the `requirements.txt` file are prioritized in case there's a difference in versions with those pre-installed in the base pre-configured environment.
|
|
8
|
+
Currently, models must use python 3.8 (any 3.8.x). Supported torch versions are 1.13.1 and 2.0.1.
|
|
9
|
+
If your model depends on torch, torch must be listed in your requirements.txt file (even if it is
|
|
10
|
+
already a dependency of another package). An appropriate supported torch version will be selected
|
|
11
|
+
based on your requirements.txt.
|
|
@@ -78,11 +78,9 @@ class TritonModelRepository:
|
|
|
78
78
|
pass
|
|
79
79
|
else:
|
|
80
80
|
continue
|
|
81
|
-
# gen requirements
|
|
81
|
+
# gen requirements
|
|
82
82
|
with open(os.path.join(repository_path, "requirements.txt"), "w") as f:
|
|
83
83
|
f.write("clarifai>9.5.3\ntritonclient[all]") # for model upload utils
|
|
84
|
-
with open(os.path.join(repository_path, "triton_conda.yaml"), "w") as conda_env:
|
|
85
|
-
conda_env.write("name: triton_conda-cp3.8-torch1.13.1-19f97078")
|
|
86
84
|
|
|
87
85
|
if not os.path.isdir(model_version_path):
|
|
88
86
|
os.mkdir(model_version_path)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from schema import And, Optional, Regex, Schema
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_schema() -> Schema:
  """Build the validation schema shared by rank and filter arguments.

  The returned schema accepts a list whose items are dicts. Every key of an
  item is optional:

  - 'image_url': string starting with 'http://' or 'https://'
  - 'text_raw': non-empty string
  - 'metadata': dict
  - 'image_bytes': bytes
  - 'geo_point': dict with 'longitude' (float), 'latitude' (float) and
    'geo_limit' (int)
  - 'concepts': list of non-empty concept dicts, each with optional keys:
      - 'name': alphanumeric string with internal dashes/underscores
      - 'id': non-empty string
      - 'language': non-empty string
      - 'value': integer, 0 or 1

  Returns:
      Schema: The schema for rank and filter.
  """
  non_empty_str = And(str, len)

  # A single entry of the 'concepts' list.
  concept_fields = {
      Optional('value'): And(int, lambda v: v in [0, 1]),
      Optional('id'): non_empty_str,
      Optional('language'): non_empty_str,
      # Non-empty strings with internal dashes and underscores.
      Optional('name'): And(str, len, Regex(r'^[0-9A-Za-z]+([-_][0-9A-Za-z]+)*$')),
  }
  concept_schema = Schema(concept_fields)

  geo_point_fields = {
      'longitude': float,
      'latitude': float,
      'geo_limit': int,
  }

  # A single rank or filter item.
  item_fields = {
      Optional('image_url'): And(str, Regex(r'^https?://')),
      Optional('text_raw'): non_empty_str,
      Optional('metadata'): dict,
      Optional('image_bytes'): bytes,
      Optional('geo_point'): geo_point_fields,
      # Each concept must satisfy concept_schema and carry at least one key.
      Optional("concepts"): And(
          list,
          lambda concepts: all(concept_schema.is_valid(c) and len(c) > 0 for c in concepts)),
  }

  # Rank and filter arguments are lists of such items.
  return Schema([Schema(item_fields)])
|
clarifai/versions.py
CHANGED
clarifai/workflows/export.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
|
-
import
|
|
1
|
+
from typing import Any, Dict
|
|
2
2
|
|
|
3
|
+
import yaml
|
|
3
4
|
from google.protobuf.json_format import MessageToDict
|
|
4
5
|
|
|
5
6
|
VALID_YAML_KEYS = ["workflow", "id", "nodes", "node_inputs", "node_id", "model"]
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def clean_up_unused_keys(wf: dict):
|
|
9
|
-
"""Removes unused keys from dict before exporting to yaml.
|
|
10
|
-
Supports nested dicts.
|
|
11
|
-
"""
|
|
10
|
+
"""Removes unused keys from dict before exporting to yaml. Supports nested dicts."""
|
|
12
11
|
new_wf = dict()
|
|
13
12
|
for key, val in wf.items():
|
|
14
13
|
if key not in VALID_YAML_KEYS:
|
|
@@ -44,10 +43,12 @@ class Exporter:
|
|
|
44
43
|
def __enter__(self):
|
|
45
44
|
return self
|
|
46
45
|
|
|
47
|
-
def parse(self):
|
|
46
|
+
def parse(self) -> Dict[str, Any]:
|
|
48
47
|
"""Reads a resources_pb2.Workflow object (e.g. from a GetWorkflow response)
|
|
49
|
-
|
|
50
|
-
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
dict: A dict representation of the workflow.
|
|
51
|
+
"""
|
|
51
52
|
if isinstance(self.wf, list):
|
|
52
53
|
self.wf = self.wf[0]
|
|
53
54
|
wf = {"workflow": MessageToDict(self.wf, preserving_proto_field_name=True)}
|
|
@@ -57,7 +58,7 @@ class Exporter:
|
|
|
57
58
|
|
|
58
59
|
def export(self, out_path):
|
|
59
60
|
with open(out_path, 'w') as out_file:
|
|
60
|
-
yaml.dump(self.wf_dict["workflow"], out_file, default_flow_style=
|
|
61
|
+
yaml.dump(self.wf_dict["workflow"], out_file, default_flow_style=False)
|
|
61
62
|
|
|
62
63
|
def __exit__(self, *args):
|
|
63
64
|
self.close()
|
clarifai/workflows/utils.py
CHANGED
|
@@ -31,7 +31,7 @@ def is_same_yaml_model(api_model: resources_pb2.Model, yaml_model: Dict) -> bool
|
|
|
31
31
|
|
|
32
32
|
yaml_model_from_api = dict()
|
|
33
33
|
for k, _ in yaml_model.items():
|
|
34
|
-
if k == "output_info":
|
|
34
|
+
if k == "output_info" and api_model["model_version"].get("output_info", "") != "":
|
|
35
35
|
yaml_model_from_api[k] = dict(params=api_model["model_version"]["output_info"].get("params"))
|
|
36
36
|
else:
|
|
37
37
|
yaml_model_from_api[k] = api_model.get(k)
|
clarifai/workflows/validate.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from schema import And, Optional, Regex, Schema, SchemaError, Use
|
|
2
2
|
|
|
3
3
|
# Non-empty, up to 48-character ASCII strings with internal dashes and underscores.
|
|
4
|
-
_id_validator = And(str, lambda s: 0 < len(s) <=
|
|
4
|
+
_id_validator = And(str, lambda s: 0 < len(s) <= 48, Regex(r'^[0-9A-Za-z]+([-_][0-9A-Za-z]+)*$'))
|
|
5
5
|
|
|
6
6
|
# Exactly 32 hex characters, converted to lower-case before matching.
# The pattern is anchored at both ends (like _id_validator) so longer strings are rejected.
_hex_id_validator = And(str, Use(str.lower), Regex(r'^[0-9a-f]{32}$'))
|