pixeltable 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the package's registry page for more details.

Files changed (110)
  1. pixeltable/__init__.py +20 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +23 -7
  4. pixeltable/catalog/insertable_table.py +32 -19
  5. pixeltable/catalog/table.py +210 -20
  6. pixeltable/catalog/table_version.py +272 -111
  7. pixeltable/catalog/table_version_path.py +6 -1
  8. pixeltable/dataframe.py +184 -110
  9. pixeltable/datatransfer/__init__.py +1 -0
  10. pixeltable/datatransfer/label_studio.py +526 -0
  11. pixeltable/datatransfer/remote.py +113 -0
  12. pixeltable/env.py +213 -79
  13. pixeltable/exec/__init__.py +2 -1
  14. pixeltable/exec/data_row_batch.py +6 -7
  15. pixeltable/exec/expr_eval_node.py +28 -28
  16. pixeltable/exec/sql_scan_node.py +7 -6
  17. pixeltable/exprs/__init__.py +4 -3
  18. pixeltable/exprs/column_ref.py +11 -2
  19. pixeltable/exprs/comparison.py +39 -1
  20. pixeltable/exprs/data_row.py +7 -0
  21. pixeltable/exprs/expr.py +26 -19
  22. pixeltable/exprs/function_call.py +17 -18
  23. pixeltable/exprs/globals.py +14 -2
  24. pixeltable/exprs/image_member_access.py +9 -28
  25. pixeltable/exprs/in_predicate.py +96 -0
  26. pixeltable/exprs/inline_array.py +13 -11
  27. pixeltable/exprs/inline_dict.py +15 -13
  28. pixeltable/exprs/row_builder.py +7 -1
  29. pixeltable/exprs/similarity_expr.py +67 -0
  30. pixeltable/ext/functions/whisperx.py +30 -0
  31. pixeltable/ext/functions/yolox.py +16 -0
  32. pixeltable/func/__init__.py +0 -2
  33. pixeltable/func/aggregate_function.py +5 -2
  34. pixeltable/func/callable_function.py +57 -13
  35. pixeltable/func/expr_template_function.py +14 -3
  36. pixeltable/func/function.py +35 -4
  37. pixeltable/func/signature.py +5 -15
  38. pixeltable/func/udf.py +8 -12
  39. pixeltable/functions/fireworks.py +9 -4
  40. pixeltable/functions/huggingface.py +48 -5
  41. pixeltable/functions/openai.py +49 -11
  42. pixeltable/functions/pil/image.py +61 -64
  43. pixeltable/functions/together.py +32 -6
  44. pixeltable/functions/util.py +0 -43
  45. pixeltable/functions/video.py +46 -8
  46. pixeltable/globals.py +443 -0
  47. pixeltable/index/__init__.py +1 -0
  48. pixeltable/index/base.py +9 -2
  49. pixeltable/index/btree.py +54 -0
  50. pixeltable/index/embedding_index.py +91 -15
  51. pixeltable/io/__init__.py +4 -0
  52. pixeltable/io/globals.py +59 -0
  53. pixeltable/{utils → io}/hf_datasets.py +48 -17
  54. pixeltable/io/pandas.py +148 -0
  55. pixeltable/{utils → io}/parquet.py +58 -33
  56. pixeltable/iterators/__init__.py +1 -1
  57. pixeltable/iterators/base.py +8 -4
  58. pixeltable/iterators/document.py +225 -93
  59. pixeltable/iterators/video.py +16 -9
  60. pixeltable/metadata/__init__.py +8 -4
  61. pixeltable/metadata/converters/convert_12.py +3 -0
  62. pixeltable/metadata/converters/convert_13.py +41 -0
  63. pixeltable/metadata/converters/convert_14.py +13 -0
  64. pixeltable/metadata/converters/convert_15.py +29 -0
  65. pixeltable/metadata/converters/util.py +63 -0
  66. pixeltable/metadata/schema.py +12 -6
  67. pixeltable/plan.py +11 -24
  68. pixeltable/store.py +16 -23
  69. pixeltable/tool/create_test_db_dump.py +49 -14
  70. pixeltable/type_system.py +27 -58
  71. pixeltable/utils/coco.py +94 -0
  72. pixeltable/utils/documents.py +42 -12
  73. pixeltable/utils/http_server.py +70 -0
  74. pixeltable-0.2.7.dist-info/METADATA +137 -0
  75. pixeltable-0.2.7.dist-info/RECORD +126 -0
  76. {pixeltable-0.2.5.dist-info → pixeltable-0.2.7.dist-info}/WHEEL +1 -1
  77. pixeltable/client.py +0 -600
  78. pixeltable/exprs/image_similarity_predicate.py +0 -58
  79. pixeltable/func/batched_function.py +0 -53
  80. pixeltable/func/nos_function.py +0 -202
  81. pixeltable/tests/conftest.py +0 -171
  82. pixeltable/tests/ext/test_yolox.py +0 -21
  83. pixeltable/tests/functions/test_fireworks.py +0 -43
  84. pixeltable/tests/functions/test_functions.py +0 -60
  85. pixeltable/tests/functions/test_huggingface.py +0 -158
  86. pixeltable/tests/functions/test_openai.py +0 -162
  87. pixeltable/tests/functions/test_together.py +0 -112
  88. pixeltable/tests/test_audio.py +0 -65
  89. pixeltable/tests/test_catalog.py +0 -27
  90. pixeltable/tests/test_client.py +0 -21
  91. pixeltable/tests/test_component_view.py +0 -379
  92. pixeltable/tests/test_dataframe.py +0 -440
  93. pixeltable/tests/test_dirs.py +0 -107
  94. pixeltable/tests/test_document.py +0 -120
  95. pixeltable/tests/test_exprs.py +0 -802
  96. pixeltable/tests/test_function.py +0 -332
  97. pixeltable/tests/test_index.py +0 -138
  98. pixeltable/tests/test_migration.py +0 -44
  99. pixeltable/tests/test_nos.py +0 -54
  100. pixeltable/tests/test_snapshot.py +0 -231
  101. pixeltable/tests/test_table.py +0 -1343
  102. pixeltable/tests/test_transactional_directory.py +0 -42
  103. pixeltable/tests/test_types.py +0 -52
  104. pixeltable/tests/test_video.py +0 -159
  105. pixeltable/tests/test_view.py +0 -535
  106. pixeltable/tests/utils.py +0 -442
  107. pixeltable/utils/clip.py +0 -18
  108. pixeltable-0.2.5.dist-info/METADATA +0 -128
  109. pixeltable-0.2.5.dist-info/RECORD +0 -139
  110. {pixeltable-0.2.5.dist-info → pixeltable-0.2.7.dist-info}/LICENSE +0 -0
@@ -0,0 +1,526 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Iterator, Optional, Literal
7
+ from xml.etree import ElementTree
8
+
9
+ import PIL.Image
10
+ import label_studio_sdk
11
+ import more_itertools
12
+ from requests.exceptions import HTTPError
13
+
14
+ import pixeltable as pxt
15
+ import pixeltable.env as env
16
+ import pixeltable.exceptions as excs
17
+ from pixeltable import Table
18
+ from pixeltable.datatransfer.remote import Remote
19
+ from pixeltable.exprs import ColumnRef, DataRow
20
+ from pixeltable.utils import coco
21
+
22
+ _logger = logging.getLogger('pixeltable')
23
+
24
+
25
@env.register_client('label_studio')
def _(api_key: str, url: str) -> label_studio_sdk.Client:
    """
    Builds a Label Studio SDK client from an API key and a server URL.

    Registered with `env` through `register_client('label_studio')`.
    """
    client = label_studio_sdk.Client(url=url, api_key=api_key)
    return client
28
+
29
+
30
def _label_studio_client() -> label_studio_sdk.Client:
    """Returns the client registered as `label_studio` in the current Pixeltable `Env`."""
    current_env = env.Env.get()
    return current_env.get_client('label_studio')
32
+
33
+
34
class LabelStudioProject(Remote):
    """
    A [`Remote`][pixeltable.datatransfer.Remote] that represents a Label Studio project, providing functionality
    for synchronizing between a Pixeltable table and a Label Studio project.

    The API key and URL for a valid Label Studio server must be specified in Pixeltable config. Either:

    * Set the `LABEL_STUDIO_API_KEY` and `LABEL_STUDIO_URL` environment variables; or
    * Specify `api_key` and `url` fields in the `label-studio` section of `$PIXELTABLE_HOME/config.yaml`.
    """
    # TODO(aaron-siegel): Add link in docstring to a Label Studio howto

    def __init__(self, project_id: int, media_import_method: Literal['post', 'file']):
        """
        Args:
            project_id: The id of the Label Studio project.
            media_import_method: How media columns are sent to Label Studio (see `create()`).
        """
        self.project_id = project_id
        self.media_import_method = media_import_method
        # Populated lazily (and then cached) by the `project` property.
        self._project: Optional[label_studio_sdk.project.Project] = None

    @classmethod
    def create(
        cls,
        title: str,
        label_config: str,
        media_import_method: Literal['post', 'file'] = 'file',
        **kwargs: Any
    ) -> 'LabelStudioProject':
        """
        Creates a new Label Studio project, using the Label Studio client configured in Pixeltable.

        Args:
            title: The title of the project.
            label_config: The Label Studio project configuration, in XML format.
            media_import_method: The method to use when importing media columns to Label Studio:
                - `file`: Media will be sent to Label Studio as a file on the local filesystem. This method can be
                    used if Pixeltable and Label Studio are running on the same host.
                - `post`: Media will be sent to Label Studio via HTTP post. This should generally only be used for
                    prototyping; due to restrictions in Label Studio, it can only be used with projects that have
                    just one data field.
            **kwargs: Additional keyword arguments for the new project; these will be passed to `start_project`
                in the Label Studio SDK.

        Raises:
            excs.Error: If the config is invalid, if `post` is requested with more than one data key, or if
                `file` is requested but the Label Studio server does not have local file serving enabled.
        """
        # TODO(aaron-siegel): Add media_import_method = 'url' as an option
        # Check that the config is valid before creating the project
        config = cls.__parse_project_config(label_config)
        if media_import_method == 'post' and len(config.data_keys) > 1:
            raise excs.Error('`media_import_method` cannot be `post` if there is more than one data key')

        project = _label_studio_client().start_project(title=title, label_config=label_config, **kwargs)

        if media_import_method == 'file':
            # We need to set up a local storage connection to receive media files
            os.environ['LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT'] = str(env.Env.get().home)
            try:
                project.connect_local_import_storage(local_store_path=str(env.Env.get().media_dir))
            except HTTPError as exc:
                if cls.__is_local_files_disabled_error(exc):
                    raise excs.Error(
                        '`media_import_method` is set to `file`, but your Label Studio server is not configured '
                        'for local file storage.\nPlease set the `LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED` '
                        'environment variable to `true` in the environment where your Label Studio server is running.'
                    ) from exc
                raise  # Handle any other exception type normally

        project_id = project.get_params()['id']
        return cls(project_id, media_import_method)

    @staticmethod
    def __is_local_files_disabled_error(exc: HTTPError) -> bool:
        """
        Returns `True` if `exc` is an HTTP 400 whose JSON body reports (in
        `validation_errors.non_field_errors`) that `LOCAL_FILES_SERVING_ENABLED` is not set.
        """
        # `requests` exposes the HTTP status via `exc.response`; `exc.errno` is not the status code.
        response = exc.response
        if response is None or response.status_code != 400:
            return False
        try:
            body = json.loads(response.text)
        except ValueError:
            return False
        if not isinstance(body, dict):
            return False
        validation_errors = body.get('validation_errors')
        if not isinstance(validation_errors, dict):
            return False
        non_field_errors = validation_errors.get('non_field_errors')
        if not non_field_errors:
            return False
        return 'LOCAL_FILES_SERVING_ENABLED' in str(non_field_errors[0])

    @property
    def project(self) -> label_studio_sdk.project.Project:
        """The `Project` object corresponding to this Label Studio project."""
        if self._project is None:
            try:
                self._project = _label_studio_client().get_project(self.project_id)
            except HTTPError as exc:
                raise excs.Error(f'Could not locate Label Studio project: {self.project_id} '
                                 '(cannot connect to server or project no longer exists)') from exc
        return self._project

    @property
    def project_params(self) -> dict[str, Any]:
        """The parameters of this Label Studio project."""
        return self.project.get_params()

    @property
    def project_title(self) -> str:
        """The title of this Label Studio project."""
        return self.project_params['title']

    @property
    def __project_config(self) -> '_LabelStudioConfig':
        # Re-parsed from the server's current `label_config` on every access.
        return self.__parse_project_config(self.project_params['label_config'])

    def get_export_columns(self) -> dict[str, pxt.ColumnType]:
        """
        The data keys and preannotation fields specified in this Label Studio project.
        """
        return self.__project_config.export_columns

    def get_import_columns(self) -> dict[str, pxt.ColumnType]:
        """
        Always contains a single entry:

        ```
        {"annotations": pxt.JsonType(nullable=True)}
        ```
        """
        return {ANNOTATIONS_COLUMN: pxt.JsonType(nullable=True)}

    def sync(self, t: Table, col_mapping: dict[str, str], export_data: bool, import_data: bool) -> None:
        _logger.info(f'Syncing Label Studio project "{self.project_title}" with table `{t.get_name()}`'
                     f' (export: {export_data}, import: {import_data}).')
        # Collect all existing tasks into a dict with entries `rowid: task`
        tasks = {tuple(task['meta']['rowid']): task for task in self.__fetch_all_tasks()}
        if export_data:
            # Note: this also removes from `tasks` any tasks it deletes as stale.
            self.__update_tasks(t, col_mapping, tasks)
        if import_data:
            self.__update_table_from_tasks(t, col_mapping, tasks)

    def __fetch_all_tasks(self) -> Iterator[dict]:
        """
        Yields every task in the project that carries a `rowid` in its metadata, page by page.

        Tasks without a `rowid` were not created by Pixeltable. They are skipped, and a warning is
        logged once iteration is complete.
        """
        page = 1
        unknown_task_count = 0
        while True:
            result = self.project.get_paginated_tasks(page=page, page_size=_PAGE_SIZE)
            if result.get('end_pagination'):
                break
            for task in result['tasks']:
                rowid = task['meta'].get('rowid')
                if rowid is None:
                    unknown_task_count += 1
                else:
                    yield task
            page += 1
        if unknown_task_count > 0:
            _logger.warning(
                f'Skipped {unknown_task_count} unrecognized task(s) when syncing Label Studio project "{self.project_title}".'
            )

    def __update_tasks(self, t: Table, col_mapping: dict[str, str], existing_tasks: dict[tuple, dict]) -> None:
        """
        Exports rows of `t` to Label Studio tasks, using the configured `media_import_method`.

        Raises:
            excs.Error: If `media_import_method` is not `post` or `file`.
        """
        config = self.__project_config

        # Columns in `t` that map to Label Studio data keys
        t_data_cols = [
            t_col_name for t_col_name, r_col_name in col_mapping.items()
            if r_col_name in config.data_keys
        ]

        # Columns in `t` that map to `rectanglelabels` preannotations
        t_rl_cols = [
            t_col_name for t_col_name, r_col_name in col_mapping.items()
            if r_col_name in config.rectangle_labels
        ]

        # Destinations for `rectanglelabels` preannotations
        # NOTE(review): this is in config order, while `t_rl_cols` is in `col_mapping` order; the two are
        # zipped by index downstream, so they are assumed to line up. TODO confirm.
        rl_info = list(config.rectangle_labels.values())

        _logger.debug('`t_data_cols`: %s', t_data_cols)
        _logger.debug('`t_rl_cols`: %s', t_rl_cols)
        _logger.debug('`rl_info`: %s', rl_info)

        if self.media_import_method == 'post':
            # Send media to Label Studio by HTTP post.
            self.__update_tasks_by_post(t, col_mapping, existing_tasks, t_data_cols[0], t_rl_cols, rl_info)
        elif self.media_import_method == 'file':
            # Send media to Label Studio by local file transfer.
            self.__update_tasks_by_files(t, col_mapping, existing_tasks, t_data_cols, t_rl_cols, rl_info)
        else:
            raise excs.Error(f'Unsupported `media_import_method`: {self.media_import_method!r}')

    def __update_tasks_by_post(
        self,
        t: Table,
        col_mapping: dict[str, str],
        existing_tasks: dict[tuple, dict],
        media_col_name: str,
        t_rl_cols: list[str],
        rl_info: list['_RectangleLabel']
    ) -> None:
        """
        Creates one task per row that has no existing task, uploading the row's media via HTTP post.
        It then attaches `rowid`/`v_min` metadata and `rectanglelabels` predictions, and finally deletes
        stale tasks. Existing tasks are not updated in this mode.
        """
        is_stored = t[media_col_name].col.is_stored
        # If it's a stored column, we can use `localpath`
        localpath_col_opt = [t[media_col_name].localpath] if is_stored else []
        # Select the media column, rectanglelabels columns, and localpath (if appropriate)
        rows = t.select(t[media_col_name], *[t[col] for col in t_rl_cols], *localpath_col_opt)
        tasks_created = 0
        row_ids_in_pxt: set[tuple] = set()
        # Slot indices are only available once iteration has begun; populated on the first row.
        media_col_idx: Optional[int] = None
        rl_col_idxs: Optional[list[int]] = None

        for row in rows._exec():
            if media_col_idx is None:
                media_col_idx = rows._select_list_exprs[0].slot_idx
                rl_col_idxs = [expr.slot_idx for expr in rows._select_list_exprs[1: 1 + len(t_rl_cols)]]
            row_ids_in_pxt.add(row.rowid)
            if row.rowid in existing_tasks:
                continue

            # Upload the media file to Label Studio
            task_id: int
            if is_stored:
                # There is an existing localpath; use it!
                localpath_col_idx = rows._select_list_exprs[-1].slot_idx
                file = Path(row.vals[localpath_col_idx])
                task_id = self.project.import_tasks(file)[0]
            else:
                # No localpath; create a temp file and upload it
                assert isinstance(row.vals[media_col_idx], PIL.Image.Image)
                file = env.Env.get().create_tmp_path(extension='.png')
                try:
                    row.vals[media_col_idx].save(file, format='png')
                    task_id = self.project.import_tasks(file)[0]
                finally:
                    # Clean up even if saving or uploading fails; the file may not exist if `save` failed.
                    Path(file).unlink(missing_ok=True)

            # Update the task with `rowid` metadata
            self.project.update_task(task_id, meta={'rowid': row.rowid, 'v_min': row.v_min})

            # Convert coco annotations to predictions
            coco_annotations = [row.vals[i] for i in rl_col_idxs]
            _logger.debug('`coco_annotations`: %s', coco_annotations)
            predictions = [
                self.__coco_to_predictions(
                    coco_annotations[i], col_mapping[t_rl_cols[i]], rl_info[i], task_id=task_id
                )
                for i in range(len(coco_annotations))
            ]
            _logger.debug('`predictions`: %s', predictions)
            self.project.create_predictions(predictions)
            tasks_created += 1

        print(f'Created {tasks_created} new task(s) in {self}.')

        self.__delete_stale_tasks(existing_tasks, row_ids_in_pxt, tasks_created)

    def __update_tasks_by_files(
        self,
        t: Table,
        col_mapping: dict[str, str],
        existing_tasks: dict[tuple, dict],
        t_data_cols: list[str],
        t_rl_cols: list[str],
        rl_info: list['_RectangleLabel']
    ) -> None:
        """
        Creates tasks for new rows, and updates tasks whose row has a newer `v_min` than the one recorded
        in the task metadata. Media is referenced through Label Studio's local-files path format rather
        than uploaded. Stale tasks are deleted afterwards.
        """
        r_data_cols = [col_mapping[col_name] for col_name in t_data_cols]
        # One entry per data column, in `t_data_cols` order (dicts preserve insertion order).
        col_refs = {}
        for col_name in t_data_cols:
            if not t[col_name].col_type.is_media_type():
                # Not a media column; query the data directly
                col_refs[col_name] = t[col_name]
            elif t[col_name].col.stored_proxy:
                # Media column that has a stored proxy; use it. We have to give it a name,
                # since it's an anonymous column
                col_refs[f'{col_name}_proxy'] = ColumnRef(t[col_name].col.stored_proxy).localpath
            else:
                # Media column without a stored proxy; this means it's a stored computed column,
                # and we can just use the localpath
                col_refs[col_name] = t[col_name].localpath

        df = t.select(*[t[col] for col in t_rl_cols], **col_refs)
        rl_col_idxs: Optional[list[int]] = None  # We have to wait until we begin iterating to populate these
        data_col_idxs: Optional[list[int]] = None

        row_ids_in_pxt: set[tuple] = set()
        tasks_created = 0
        tasks_updated = 0
        page = []

        def create_task_info(row: DataRow) -> dict:
            # Builds the task payload (data, rowid/v_min metadata, predictions) for one row.
            data_vals = [row.vals[idx] for idx in data_col_idxs]
            coco_annotations = [row.vals[idx] for idx in rl_col_idxs]
            # For media columns, we need to transform the paths into Label Studio's bespoke path format
            for i in range(len(t_data_cols)):
                if t[t_data_cols[i]].col_type.is_media_type():
                    data_vals[i] = self.__localpath_to_lspath(data_vals[i])
            predictions = [
                self.__coco_to_predictions(coco_annotations[i], col_mapping[t_rl_cols[i]], rl_info[i])
                for i in range(len(coco_annotations))
            ]
            return {
                'data': dict(zip(r_data_cols, data_vals)),
                'meta': {'rowid': row.rowid, 'v_min': row.v_min},
                'predictions': predictions
            }

        for row in df._exec():
            if rl_col_idxs is None:
                rl_col_idxs = [expr.slot_idx for expr in df._select_list_exprs[:len(t_rl_cols)]]
                data_col_idxs = [expr.slot_idx for expr in df._select_list_exprs[len(t_rl_cols):]]
            row_ids_in_pxt.add(row.rowid)
            if row.rowid in existing_tasks:
                # A task for this row already exists; see if it needs an update.
                # Get the v_min record from task metadata. Default to 0 if no v_min record is found
                old_v_min = int(existing_tasks[row.rowid]['meta'].get('v_min', 0))
                if row.v_min > old_v_min:
                    _logger.debug(f'Updating task for rowid {row.rowid} ({row.v_min} > {old_v_min}).')
                    task_info = create_task_info(row)
                    self.project.update_task(existing_tasks[row.rowid]['id'], **task_info)
                    tasks_updated += 1
            else:
                # No task exists for this row; we need to create one.
                page.append(create_task_info(row))
                tasks_created += 1
                # Import new tasks in batches of `_PAGE_SIZE`.
                if len(page) == _PAGE_SIZE:
                    self.project.import_tasks(page)
                    page.clear()

        if len(page) > 0:
            self.project.import_tasks(page)

        print(f'Created {tasks_created} new task(s) and updated {tasks_updated} existing task(s) in {self}.')

        self.__delete_stale_tasks(existing_tasks, row_ids_in_pxt, tasks_created)

    @classmethod
    def __localpath_to_lspath(cls, localpath: str) -> str:
        """
        Converts a local file path under the Pixeltable home directory into a Label Studio
        local-files URL path (`/data/local-files/?d=<path relative to home>`).
        """
        assert isinstance(localpath, str)
        relpath = Path(localpath).relative_to(env.Env.get().home)
        # Always use forward slashes, so that the path is well-formed on Windows too.
        return f'/data/local-files/?d={relpath.as_posix()}'

    def __delete_stale_tasks(
        self, existing_tasks: dict[tuple, dict], row_ids_in_pxt: set[tuple], tasks_created: int
    ) -> None:
        """
        Deletes the tasks whose `rowid` is no longer present in the table.

        Deleted tasks are also removed from `existing_tasks` (in place). This way, a subsequent import
        in the same `sync()` does not try to apply their annotations to rows that no longer exist.
        """
        stale_rowids = [rowid for rowid in existing_tasks if rowid not in row_ids_in_pxt]
        tasks_to_delete = [existing_tasks[rowid]['id'] for rowid in stale_rowids]
        # Sanity check the math
        assert len(tasks_to_delete) == len(existing_tasks) + tasks_created - len(row_ids_in_pxt)

        if len(tasks_to_delete) > 0:
            self.project.delete_tasks(tasks_to_delete)
            for rowid in stale_rowids:
                del existing_tasks[rowid]
            print(f'Deleted {len(tasks_to_delete)} task(s) in {self} that are no longer present in Pixeltable.')

    def __update_table_from_tasks(self, t: Table, col_mapping: dict[str, str], tasks: dict[tuple, dict]) -> None:
        """
        Writes each task's `annotations` into the mapped column of the corresponding row of `t`.
        Tasks with no annotations write `None`.
        """
        # `col_mapping` is guaranteed to be a one-to-one dict whose values are a superset
        # of `get_import_columns`
        assert ANNOTATIONS_COLUMN in col_mapping.values()
        annotations_column = next(k for k, v in col_mapping.items() if v == ANNOTATIONS_COLUMN)
        updates = [
            {
                '_rowid': task['meta']['rowid'],
                # Replace [] by None to indicate no annotations. We do want to sync rows with no annotations,
                # in order to properly handle the scenario where existing annotations have been deleted in
                # Label Studio.
                annotations_column: task[ANNOTATIONS_COLUMN] if len(task[ANNOTATIONS_COLUMN]) > 0 else None
            }
            for task in tasks.values()
        ]
        if len(updates) > 0:
            _logger.info(
                f'Updating table `{t.get_name()}`, column `{annotations_column}` '
                f'with annotations from {len(updates)} task(s).'
            )
            t.batch_update(updates)
        annotations_count = sum(len(task[ANNOTATIONS_COLUMN]) for task in tasks.values())
        print(f'Synced {annotations_count} annotation(s) from {len(updates)} existing task(s) in {self}.')

    def to_dict(self) -> dict[str, Any]:
        return {'project_id': self.project_id, 'media_import_method': self.media_import_method}

    @classmethod
    def from_dict(cls, md: dict[str, Any]) -> 'LabelStudioProject':
        return cls(md['project_id'], md['media_import_method'])

    def __repr__(self) -> str:
        return f'LabelStudioProject `{self.project_title}`'

    @classmethod
    def __parse_project_config(cls, xml_config: str) -> '_LabelStudioConfig':
        """
        Parses a Label Studio XML config, extracting the names and Pixeltable types of
        all input variables.

        Raises:
            excs.Error: If the root is not a `View`, a data key uses an unsupported tag, a `rectanglelabels`
                label is not a COCO category, or a `toName` references an unknown data key.
        """
        root: ElementTree.Element = ElementTree.fromstring(xml_config)
        if root.tag.lower() != 'view':
            raise excs.Error('Root of Label Studio config must be a `View`')
        config = _LabelStudioConfig(
            data_keys=dict(cls.__parse_data_keys_config(root)),
            rectangle_labels=dict(cls.__parse_rectangle_labels_config(root))
        )
        config.validate()
        return config

    @classmethod
    def __parse_data_keys_config(cls, root: ElementTree.Element) -> Iterator[tuple[str, '_DataKey']]:
        """
        Yields `(field_name, _DataKey)` for each direct child of `root` whose `value` attribute
        starts with `$`. Here `field_name` is the text after the `$`.
        """
        for element in root:
            value = element.attrib.get('value', '')
            if value.startswith('$'):
                remote_col_name = value[1:]
                data_key_name = element.attrib.get('name')
                element_type = _LS_TAG_MAP.get(element.tag.lower())
                if element_type is None:
                    raise excs.Error(
                        f'Unsupported Label Studio data type: `{element.tag}` (in data key `{remote_col_name}`)'
                    )
                yield remote_col_name, _DataKey(data_key_name, element_type)

    @classmethod
    def __parse_rectangle_labels_config(cls, root: ElementTree.Element) -> Iterator[tuple[str, '_RectangleLabel']]:
        """
        Yields `(name, _RectangleLabel)` for each direct `RectangleLabels` child of `root`.
        Each `Label` value must be a COCO 2017 category name.
        """
        for element in root:
            if element.tag.lower() == 'rectanglelabels':
                name = element.attrib['name']
                to_name = element.attrib['toName']
                labels = [
                    child.attrib['value']
                    for child in element if child.tag.lower() == 'label'
                ]
                for label in labels:
                    if label not in coco.COCO_2017_CATEGORIES.values():
                        raise excs.Error(f'Label in `rectanglelabels` config is not a valid COCO object name: {label}')
                yield name, _RectangleLabel(to_name=to_name, labels=labels)

    @classmethod
    def __coco_to_predictions(
        cls,
        coco_annotations: dict[str, Any],
        from_name: str,
        rl_info: '_RectangleLabel',
        task_id: Optional[int] = None
    ) -> dict[str, Any]:
        """
        Converts a COCO-style annotation dict into a Label Studio prediction payload.

        The dict must have `image.width`, `image.height`, and `annotations` entries, each with a
        `bbox` and a `category`. Only entries whose category name appears in `rl_info.labels` are kept.
        Boxes are converted to percentages of the image dimensions. If `task_id` is given, it is
        included as `task`.
        """
        width = coco_annotations['image']['width']
        height = coco_annotations['image']['height']
        result = [
            {
                'id': f'result_{i}',
                'type': 'rectanglelabels',
                'from_name': from_name,
                'to_name': rl_info.to_name,
                'image_rotation': 0,
                'original_width': width,
                'original_height': height,
                'value': {
                    'rotation': 0,
                    # Label Studio expects image coordinates as % of image dimensions
                    'x': entry['bbox'][0] * 100.0 / width,
                    'y': entry['bbox'][1] * 100.0 / height,
                    'width': entry['bbox'][2] * 100.0 / width,
                    'height': entry['bbox'][3] * 100.0 / height,
                    'rectanglelabels': [coco.COCO_2017_CATEGORIES[entry['category']]]
                }
            }
            for i, entry in enumerate(coco_annotations['annotations'])
            # include only the COCO labels that match a rectanglelabel name
            if coco.COCO_2017_CATEGORIES[entry['category']] in rl_info.labels
        ]
        if task_id is not None:
            return {'task': task_id, 'result': result}
        else:
            return {'result': result}

    def delete(self) -> None:
        """
        Deletes this Label Studio project. This will remove all data and annotations
        associated with this project in Label Studio.
        """
        title = self.project_title
        _label_studio_client().delete_project(self.project_id)
        print(f'Deleted Label Studio project: {title}')

    def __eq__(self, other) -> bool:
        return isinstance(other, LabelStudioProject) and self.project_id == other.project_id

    def __hash__(self) -> int:
        return hash(self.project_id)
483
+
484
+
485
@dataclass(frozen=True)
class _DataKey:
    """A Label Studio object tag whose `value` attribute references a data field via `$`."""
    # The tag's `name` attribute, if it has one. This may differ from the field name (the text
    # after `$` in `value`), which is what `_LabelStudioConfig.data_keys` is keyed on.
    name: Optional[str]
    # Pixeltable type for the field, looked up from the tag name in `_LS_TAG_MAP`.
    column_type: pxt.ColumnType
489
+
490
+
491
+ @dataclass(frozen=True)
492
+ class _RectangleLabel:
493
+ to_name: str
494
+ labels: list[str]
495
+
496
+
497
@dataclass(frozen=True)
class _LabelStudioConfig:
    """The parts of a Label Studio XML config that Pixeltable understands."""
    # Maps each data field name (the text after `$` in a tag's `value`) to its `_DataKey`.
    data_keys: dict[str, _DataKey]
    # Maps each `RectangleLabels` element's `name` attribute to its `_RectangleLabel`.
    rectangle_labels: dict[str, _RectangleLabel]

    def validate(self) -> None:
        """
        Checks that every `RectangleLabels` element's `toName` matches the `name` of some data key.

        Raises:
            excs.Error: If a `toName` references an unknown data key.
        """
        # Data keys without a `name` attribute cannot be the target of a `toName` reference.
        data_key_names = {key.name for key in self.data_keys.values() if key.name is not None}
        for name, rl in self.rectangle_labels.items():
            if rl.to_name not in data_key_names:
                raise excs.Error(
                    f'Invalid Label Studio configuration: `toName` attribute of RectangleLabels `{name}` '
                    f'references an unknown data key: `{rl.to_name}`'
                )

    @property
    def export_columns(self) -> dict[str, pxt.ColumnType]:
        """
        The columns Label Studio expects on export: one per data field (with its tag's type), plus one
        `JsonType` column per `RectangleLabels` element. On a name clash, the latter wins.
        """
        data_key_cols = {key_id: key_info.column_type for key_id, key_info in self.data_keys.items()}
        rl_cols = {name: pxt.JsonType() for name in self.rectangle_labels}
        return {**data_key_cols, **rl_cols}
516
+
517
+
518
# The only column reported by `LabelStudioProject.get_import_columns()`; also the task key holding annotations.
ANNOTATIONS_COLUMN = 'annotations'
# Page size used both when fetching tasks and when batch-importing new tasks.
_PAGE_SIZE = 100  # This is the default used in the LS SDK
# Lowercased Label Studio object-tag name -> Pixeltable type of the data field it displays.
_LS_TAG_MAP = dict(
    header=pxt.StringType(),
    text=pxt.StringType(),
    image=pxt.ImageType(),
    video=pxt.VideoType(),
    audio=pxt.AudioType(),
)
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import Any
5
+
6
+ import pixeltable.type_system as ts
7
+ from pixeltable import Table
8
+
9
+
10
class Remote(abc.ABC):
    """
    Abstract base class that represents a remote data store. Subclasses of `Remote` provide
    functionality for synchronizing between Pixeltable tables and stateful remote stores.
    """

    @abc.abstractmethod
    def get_export_columns(self) -> dict[str, ts.ColumnType]:
        """
        Returns the names and Pixeltable types that this `Remote` expects to see in a data export.

        Returns:
            A `dict` mapping names of expected columns to their Pixeltable types.
        """

    @abc.abstractmethod
    def get_import_columns(self) -> dict[str, ts.ColumnType]:
        """
        Returns the names and Pixeltable types that this `Remote` provides in a data import.

        Returns:
            A `dict` mapping names of provided columns to their Pixeltable types.
        """

    @abc.abstractmethod
    def sync(self, t: Table, col_mapping: dict[str, str], export_data: bool, import_data: bool) -> None:
        """
        Synchronizes the given [`Table`][pixeltable.Table] with this `Remote`. This method
        should generally not be called directly; instead, call
        [`t.sync()`][pixeltable.Table.sync].

        Args:
            t: The table to synchronize with this remote.
            col_mapping: A `dict` mapping columns in the Pixeltable table to columns in the remote store.
            export_data: If `True`, data from the table will be exported to the remote during synchronization.
            import_data: If `True`, data from the remote will be imported into the table during synchronization.
        """

    @abc.abstractmethod
    def delete(self) -> None:
        """
        Deletes this `Remote`.
        """

    @abc.abstractmethod
    def to_dict(self) -> dict[str, Any]:
        """
        Returns the metadata from which `from_dict()` can reconstruct this `Remote`.
        """

    @classmethod
    @abc.abstractmethod
    def from_dict(cls, md: dict[str, Any]) -> Remote:
        """
        Reconstructs a `Remote` from metadata produced by `to_dict()`.
        """
60
+
61
+
62
# A remote that cannot be synced, used mainly for testing.
class MockRemote(Remote):
    """
    A `Remote` with fixed export/import column schemas that cannot actually be synced.
    Used mainly for testing. Equality and hashing are based on `name` only.
    """

    def __init__(self, name: str, export_cols: dict[str, ts.ColumnType], import_cols: dict[str, ts.ColumnType]):
        """
        Args:
            name: Identifier of this remote; used for equality, hashing, and `repr`.
            export_cols: Returned as-is by `get_export_columns()`.
            import_cols: Returned as-is by `get_import_columns()`.
        """
        self.name = name
        self.export_cols = export_cols
        self.import_cols = import_cols
        self.__is_deleted = False

    def get_export_columns(self) -> dict[str, ts.ColumnType]:
        return self.export_cols

    def get_import_columns(self) -> dict[str, ts.ColumnType]:
        return self.import_cols

    def sync(self, t: Table, col_mapping: dict[str, str], export_data: bool, import_data: bool) -> None:
        """Always raises `NotImplementedError`, because a `MockRemote` cannot be synced."""
        raise NotImplementedError()

    def delete(self) -> None:
        """Marks this remote as deleted (see `is_deleted`); nothing else happens."""
        self.__is_deleted = True

    @property
    def is_deleted(self) -> bool:
        """`True` once `delete()` has been called."""
        return self.__is_deleted

    def to_dict(self) -> dict[str, Any]:
        return {
            # TODO Change in next schema version
            'name': self.name,
            'push_cols': {k: v.as_dict() for k, v in self.export_cols.items()},
            'pull_cols': {k: v.as_dict() for k, v in self.import_cols.items()}
        }

    @classmethod
    def from_dict(cls, md: dict[str, Any]) -> Remote:
        return cls(
            name=md['name'],
            # TODO Change in next schema version
            export_cols={k: ts.ColumnType.from_dict(v) for k, v in md['push_cols'].items()},
            import_cols={k: ts.ColumnType.from_dict(v) for k, v in md['pull_cols'].items()}
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, MockRemote):
            return False
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __repr__(self) -> str:
        return f'MockRemote `{self.name}`'