unitlab 2.0.1__tar.gz → 2.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: unitlab
-Version: 2.0.1
+Version: 2.0.3
 Home-page: https://github.com/teamunitlab/unitlab-sdk
 Author: Unitlab Inc.
 Author-email: team@unitlab.ai
@@ -2,7 +2,7 @@ from setuptools import find_packages, setup
 
 setup(
     name="unitlab",
-    version="2.0.1",
+    version="2.0.3",
     license="MIT",
     author="Unitlab Inc.",
     author_email="team@unitlab.ai",
@@ -8,8 +8,11 @@ import aiohttp
 import requests
 import tqdm
 
+from .dataset import DatasetUploadHandler
 from .exceptions import AuthenticationError
-from .utils import ENDPOINTS, send_request
+from .utils import BASE_URL, ENDPOINTS, send_request
+
+logger = logging.getLogger(__name__)
 
 
 class UnitlabClient:
@@ -52,7 +55,7 @@ class UnitlabClient:
             raise AuthenticationError(
                 message="Please provide the api_key argument or set UNITLAB_API_KEY in your environment."
             )
-        logging.info("Found a Unitlab API key in your environment.")
+        logger.info("Found a Unitlab API key in your environment.")
         self.api_key = api_key
         self.api_session = requests.Session()
         adapter = requests.adapters.HTTPAdapter(max_retries=3)
@@ -157,11 +160,11 @@ class UnitlabClient:
                 response.raise_for_status()
                 return 1 if response.status == 201 else 0
             except aiohttp.client_exceptions.ServerDisconnectedError as e:
-                logging.warning(f"Error: {e}: Retrying...")
+                logger.warning(f"Error: {e}: Retrying...")
                 await asyncio.sleep(0.1)
                 continue
             except Exception as e:
-                logging.error(f"Error uploading file {file} - {e}")
+                logger.error(f"Error uploading file {file} - {e}")
                 return 0
 
     async def batch_upload(
@@ -191,7 +194,7 @@ class UnitlabClient:
         for file in files:
             file_size = os.path.getsize(file) / 1024 / 1024
             if file_size > 6:
-                logging.warning(
+                logger.warning(
                     f"File {file} is too large ({file_size:.4f} megabytes) skipping, max size is 6 MB"
                 )
                 continue
@@ -200,7 +203,7 @@
         num_files = len(filtered_files)
         num_batches = (num_files + batch_size - 1) // batch_size
 
-        logging.info(f"Uploading {num_files} files to project {project_id}")
+        logger.info(f"Uploading {num_files} files to project {project_id}")
         with tqdm.tqdm(total=num_files, ncols=80) as pbar:
             async with aiohttp.ClientSession(
                 headers=self._get_headers()
@@ -244,7 +247,7 @@ class UnitlabClient:
         with open(filename, "wb") as f:
             for chunk in r.iter_content(chunk_size=1024 * 1024):
                 f.write(chunk)
-        logging.info(f"File: {os.path.abspath(filename)}")
+        logger.info(f"File: {os.path.abspath(filename)}")
         return os.path.abspath(filename)
 
     def download_dataset_files(self, dataset_id):
@@ -270,7 +273,7 @@ class UnitlabClient:
             try:
                 r.raise_for_status()
             except Exception as e:
-                logging.error(
+                logger.error(
                     f"Error downloading file {dataset_file['file_name']} - {e}"
                 )
                 return 0
@@ -292,3 +295,51 @@ class UnitlabClient:
                     pbar.update(await f)
 
         asyncio.run(main())
+
+    def create_dataset(self, name, annotation_type, categories):
+        response = self.api_session.post(
+            url=f"{BASE_URL}/api/sdk/datasets/create/",
+            headers=self._get_headers(),
+            json={
+                "name": name,
+                "annotation_type": annotation_type,
+                "classes": [
+                    {"name": category["name"], "value": category["id"]}
+                    for category in categories
+                ],
+            },
+        )
+        response.raise_for_status()
+        response = response.json()
+        return response["pk"]
+
+    def dataset_upload(
+        self, name, annotation_type, annotation_path, data_path, batch_size=100
+    ):
+        import random
+
+        handler = DatasetUploadHandler(annotation_type, annotation_path, data_path)
+        dataset_id = self.create_dataset(name, annotation_type, handler.categories)
+        img_ids = handler.getImgIds()
+        random.shuffle(img_ids)
+        image_ids = img_ids[:1000]
+        num_batches = (len(image_ids) + batch_size - 1) // batch_size
+
+        async def main():
+            with tqdm.tqdm(total=len(image_ids), ncols=80) as pbar:
+                async with aiohttp.ClientSession(
+                    headers=self._get_headers()
+                ) as session:
+                    for i in range(num_batches):
+                        tasks = []
+                        for image_id in image_ids[
+                            i * batch_size : min((i + 1) * batch_size, len(image_ids))
+                        ]:
+                            tasks.append(
+                                handler.upload_image(session, dataset_id, image_id)
+                            )
+                        for f in asyncio.as_completed(tasks):
+                            pbar.update(await f)
+
+        asyncio.run(main())
+        self.dataset_download(dataset_id, "COCO")
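
Note: the hunk above is the new programmatic upload path. A minimal usage sketch, assuming unitlab.client as the import path (the file shown in this diff); the API key, dataset name, and file paths are placeholders:

    from unitlab.client import UnitlabClient

    client = UnitlabClient(api_key="YOUR_API_KEY")  # or set UNITLAB_API_KEY in the environment
    client.dataset_upload(
        name="my-dataset",                        # placeholder dataset name
        annotation_type="img_bbox",               # the only payload type implemented in dataset.py
        annotation_path="annotations/coco.json",  # placeholder COCO annotation file
        data_path="images/",                      # placeholder image directory
        batch_size=100,                           # default from the new signature
    )

As added, dataset_upload shuffles the image ids, uploads at most 1000 of them, and finishes by calling dataset_download(dataset_id, "COCO").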
@@ -0,0 +1,303 @@
+import itertools
+import json
+import logging
+import os
+from collections import defaultdict
+
+import aiofiles
+import aiohttp
+
+from .utils import BASE_URL
+
+logger = logging.getLogger(__name__)
+
+
+class COCO:
+    def __init__(self, annotation_type, annotation_path, data_path):
+        """
+        :param annotation_type (str): one of ['img_bbox', 'img_semantic_segmentation', 'img_polygon', 'img_keypoints']
+        :param annotation_path (str): location of annotation file
+        :param data_path (str): directory containing the images
+        :return:
+        """
+        self.annotation_type = annotation_type
+        self.annotation_path = annotation_path
+        self.data_path = data_path
+        self.anns, self.cats, self.imgs = dict(), dict(), dict()
+        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+        self._load_dataset()
+
+    @staticmethod
+    def _is_array_like(obj):
+        return hasattr(obj, "__iter__") and hasattr(obj, "__len__")
+
+    def _validate(self):
+        if not os.path.isdir(self.data_path):
+            raise ValueError(
+                "Data path '{}' does not exist or is not a directory".format(
+                    self.data_path
+                )
+            )
+        if self.annotation_type not in [
+            "img_bbox",
+            "img_semantic_segmentation",
+            "img_polygon",
+            "img_keypoints",
+        ]:
+            raise ValueError(
+                "Invalid annotation type '{}'. Supported types are: ['img_bbox', 'img_semantic_segmentation', 'img_polygon', 'img_keypoints']".format(
+                    self.annotation_type
+                )
+            )
+        for required_key in ["images", "annotations", "categories"]:
+            if required_key not in self.dataset.keys():
+                raise KeyError(
+                    "Required key '{}' not found in the COCO dataset".format(
+                        required_key
+                    )
+                )
+            if len(self.dataset[required_key]) == 0:
+                raise ValueError(
+                    "Required key '{}' does not contain values".format(required_key)
+                )
+
+    def _load_dataset(self):
+        with open(self.annotation_path, "r") as f:
+            self.dataset = json.load(f)
+        self._validate()
+        self.createIndex()
+
+    def createIndex(self):
+        anns, cats, imgs = {}, {}, {}
+        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
+        for ann in self.dataset["annotations"]:
+            imgToAnns[ann["image_id"]].append(ann)
+            anns[ann["id"]] = ann
+
+        for img in self.dataset["images"]:
+            imgs[img["id"]] = img
+
+        for cat in self.dataset["categories"]:
+            cats[cat["id"]] = cat
+
+        for ann in self.dataset["annotations"]:
+            catToImgs[ann["category_id"]].append(ann["image_id"])
+
+        # create class members
+        self.anns = anns
+        self.imgToAnns = imgToAnns
+        self.catToImgs = catToImgs
+        self.imgs = imgs
+        self.cats = cats
+        self.categories = sorted(self.loadCats(self.getCatIds()), key=lambda x: x["id"])
+        self.classes = [cat["name"] for cat in self.categories]
+        self.original_category_referecences = dict()
+        for i, category in enumerate(self.categories):
+            self.original_category_referecences[category["id"]] = i
+            category["id"] = i
+
+    def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
+        """
+        Get ann ids that satisfy given filter conditions. default skips that filter
+        :param imgIds (int array) : get anns for given imgs
+               catIds (int array) : get anns for given cats
+               areaRng (float array) : get anns for given area range (e.g. [0 inf])
+               iscrowd (boolean) : get anns for given crowd label (False or True)
+        :return: ids (int array) : integer array of ann ids
+        """
+        imgIds = imgIds if self._is_array_like(imgIds) else [imgIds]
+        catIds = catIds if self._is_array_like(catIds) else [catIds]
+
+        if len(imgIds) == len(catIds) == len(areaRng) == 0:
+            anns = self.dataset["annotations"]
+        else:
+            if not len(imgIds) == 0:
+                lists = [
+                    self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns
+                ]
+                anns = list(itertools.chain.from_iterable(lists))
+            else:
+                anns = self.dataset["annotations"]
+            anns = (
+                anns
+                if len(catIds) == 0
+                else [ann for ann in anns if ann["category_id"] in catIds]
+            )
+            anns = (
+                anns
+                if len(areaRng) == 0
+                else [
+                    ann
+                    for ann in anns
+                    if ann["area"] > areaRng[0] and ann["area"] < areaRng[1]
+                ]
+            )
+        if not iscrowd == None:
+            ids = [ann["id"] for ann in anns if ann["iscrowd"] == iscrowd]
+        else:
+            ids = [ann["id"] for ann in anns]
+        return ids
+
+    def getCatIds(self, catNms=[], supNms=[], catIds=[]):
+        """
+        filtering parameters. default skips that filter.
+        :param catNms (str array) : get cats for given cat names
+        :param supNms (str array) : get cats for given supercategory names
+        :param catIds (int array) : get cats for given cat ids
+        :return: ids (int array) : integer array of cat ids
+        """
+        catNms = catNms if self._is_array_like(catNms) else [catNms]
+        supNms = supNms if self._is_array_like(supNms) else [supNms]
+        catIds = catIds if self._is_array_like(catIds) else [catIds]
+
+        if len(catNms) == len(supNms) == len(catIds) == 0:
+            cats = self.dataset["categories"]
+        else:
+            cats = self.dataset["categories"]
+            cats = (
+                cats
+                if len(catNms) == 0
+                else [cat for cat in cats if cat["name"] in catNms]
+            )
+            cats = (
+                cats
+                if len(supNms) == 0
+                else [cat for cat in cats if cat["supercategory"] in supNms]
+            )
+            cats = (
+                cats
+                if len(catIds) == 0
+                else [cat for cat in cats if cat["id"] in catIds]
+            )
+        ids = [cat["id"] for cat in cats]
+        return ids
+
+    def getImgIds(self, imgIds=[], catIds=[]):
+        """
+        Get img ids that satisfy given filter conditions.
+        :param imgIds (int array) : get imgs for given ids
+        :param catIds (int array) : get imgs with all given cats
+        :return: ids (int array) : integer array of img ids
+        """
+        imgIds = imgIds if self._is_array_like(imgIds) else [imgIds]
+        catIds = catIds if self._is_array_like(catIds) else [catIds]
+
+        if len(imgIds) == len(catIds) == 0:
+            ids = self.imgs.keys()
+        else:
+            ids = set(imgIds)
+            for i, catId in enumerate(catIds):
+                if i == 0 and len(ids) == 0:
+                    ids = set(self.catToImgs[catId])
+                else:
+                    ids &= set(self.catToImgs[catId])
+        return list(ids)
+
+    def loadAnns(self, ids=[]):
+        """
+        Load anns with the specified ids.
+        :param ids (int array) : integer ids specifying anns
+        :return: anns (object array) : loaded ann objects
+        """
+        if self._is_array_like(ids):
+            return [self.anns[id] for id in ids]
+        elif type(ids) == int:
+            return [self.anns[ids]]
+
+    def loadCats(self, ids=[]):
+        """
+        Load cats with the specified ids.
+        :param ids (int array) : integer ids specifying cats
+        :return: cats (object array) : loaded cat objects
+        """
+        if self._is_array_like(ids):
+            return [self.cats[id] for id in ids]
+        elif type(ids) == int:
+            return [self.cats[ids]]
+
+    def loadImgs(self, ids=[]):
+        """
+        Load anns with the specified ids.
+        :param ids (int array) : integer ids specifying img
+        :return: imgs (object array) : loaded img objects
+        """
+        if self._is_array_like(ids):
+            return [self.imgs[id] for id in ids]
+        elif type(ids) == int:
+            return [self.imgs[ids]]
+
+
+class DatasetUploadHandler(COCO):
+    def get_img_bbox_payload(self, anns):
+        predicted_classes = set()
+        bboxes = []
+        for ann in anns:
+            bbox = ann["bbox"]
+            bboxes.append(
+                {
+                    "point": [
+                        [bbox[0], bbox[1]],
+                        [bbox[0] + bbox[2], bbox[1]],
+                        [bbox[0] + bbox[2], bbox[1] + bbox[3]],
+                        [bbox[0], bbox[1] + bbox[3]],
+                    ],
+                    "class": self.original_category_referecences.get(
+                        ann["category_id"]
+                    ),
+                    "recognition": ann.get("recognition", ""),
+                }
+            )
+            predicted_classes.add(
+                self.original_category_referecences.get(ann["category_id"])
+            )
+        payload = {"bboxes": [bboxes]}
+        payload["predicted_classes"] = list(predicted_classes)
+        payload["classes"] = self.classes
+        return json.dumps(payload)
+
+    def get_img_polygon_payload(self, anns):
+        logger.warning("Not implemented yet")
+
+    def get_img_semantic_segmentation_payload(self, anns):
+        logger.warning("Not implemented yet")
+
+    def get_img_skeleton_payload(self, anns):
+        logger.warning("Not implemented yet")
+
+    def get_payload(self, img_id):
+        image = self.imgs[img_id]
+        ann_ids = self.getAnnIds(imgIds=img_id)
+        anns = self.loadAnns(ann_ids)
+        if not os.path.isfile(os.path.join(self.data_path, image["file_name"])):
+            logger.warning(
+                "Image file not found: {}".format(
+                    os.path.join(self.data_path, image["file_name"])
+                )
+            )
+            return
+        if len(anns) == 0:
+            logger.warning("No annotations found for image: {}".format(img_id))
+            return
+        return getattr(self, f"get_{self.annotation_type}_payload")(anns)
+
+    async def upload_image(self, session, dataset_id, image_id):
+        image = self.loadImgs(image_id)[0]
+        file_name = image["file_name"]
+        payload = self.get_payload(image_id)
+        if payload:
+            try:
+                async with aiofiles.open(
+                    os.path.join(self.data_path, file_name), "rb"
+                ) as f:
+                    form_data = aiohttp.FormData()
+                    form_data.add_field("file", await f.read(), filename=file_name)
+                    form_data.add_field("result", self.get_payload(image_id))
+                    async with session.post(
+                        f"{BASE_URL}/api/sdk/datasets/{dataset_id}/upload/",
+                        data=form_data,
+                    ) as response:
+                        response.raise_for_status()
+                        return 1
+            except Exception as e:
+                logger.error(f"Error uploading file {file_name} - {e}")
+                return 0
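
Note: DatasetUploadHandler builds its payloads locally, so it can be exercised without touching the API. A sketch, with placeholder paths:

    from unitlab.dataset import DatasetUploadHandler

    handler = DatasetUploadHandler("img_bbox", "annotations/coco.json", "images/")
    img_id = handler.getImgIds()[0]
    payload = handler.get_payload(img_id)  # JSON string of bboxes/classes, or None if the image file is missing or unannotated
    print(payload)

Only get_img_bbox_payload produces a payload; the polygon, semantic-segmentation, and skeleton builders just log "Not implemented yet". Note also that _validate accepts "img_keypoints", while the handler method and the CLI enum below use "img_skeleton".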
@@ -24,6 +24,13 @@ class DownloadType(str, Enum):
     files = "files"
 
 
+class AnnotationType(str, Enum):
+    img_bbox = "img_bbox"
+    img_polygon = "img_polygon"
+    img_semantic_segmentation = "img_semantic_segmentation"
+    img_skeleton = "img_skeleton"
+
+
 def get_client(api_key: str) -> UnitlabClient:
     return UnitlabClient(api_key=api_key)
 
@@ -97,9 +104,32 @@ def dataset_list(
     print(response.json())
 
 
+@dataset_app.command(name="upload", help="Upload dataset")
+def dataset_upload(
+    api_key: API_KEY,
+    name: Annotated[str, typer.Option(help="Name of the dataset")],
+    annotation_type: Annotated[
+        AnnotationType,
+        typer.Option(
+            help="Annotation type (img_bbox, img_polygon, img_semantic_segmentation, img_skeleton)"
+        ),
+    ],
+    annotation_path: Annotated[Path, typer.Option(help="Path to the COCO json file")],
+    data_path: Annotated[
+        Path, typer.Option(help="Directory containing the data to be uploaded")
+    ],
+    batch_size: Annotated[
+        int, typer.Option(help="Batch size for uploading images")
+    ] = 100,
+):
+    get_client(api_key).dataset_upload(
+        name, annotation_type, annotation_path, data_path, batch_size
+    )
+
+
 @dataset_app.command(name="download", help="Download dataset")
 def dataset_download(
-    pk: Annotated[Optional[UUID], typer.Argument()],
+    pk: UUID,
     api_key: API_KEY,
     download_type: Annotated[
         DownloadType,
@@ -17,6 +17,7 @@ ENDPOINTS = {
     "cli_project_members": "/api/cli/projects/{}/members/",
     "cli_datasets": "/api/cli/datasets/",
 }
+BASE_URL = os.environ.get("UNITLAB_BASE_URL", "https://api.unitlab.ai")
 
 
 def send_request(request, session=None):
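
Note: BASE_URL is resolved once, at import time, so the UNITLAB_BASE_URL override must be in the environment before the package is imported. A sketch, with a placeholder URL:

    import os

    os.environ["UNITLAB_BASE_URL"] = "https://api.staging.example"  # must be set before the import below

    from unitlab.client import UnitlabClient  # client.py and dataset.py now target the override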
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: unitlab
-Version: 2.0.1
+Version: 2.0.3
 Home-page: https://github.com/teamunitlab/unitlab-sdk
 Author: Unitlab Inc.
 Author-email: team@unitlab.ai
@@ -5,6 +5,7 @@ setup.py
 src/unitlab/__init__.py
 src/unitlab/__main__.py
 src/unitlab/client.py
+src/unitlab/dataset.py
 src/unitlab/exceptions.py
 src/unitlab/main.py
 src/unitlab/utils.py
5 files without changes