pixeltable 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable has been flagged as potentially problematic.

Files changed (76)
  1. pixeltable/__init__.py +15 -33
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +1 -1
  4. pixeltable/catalog/column.py +28 -16
  5. pixeltable/catalog/dir.py +2 -2
  6. pixeltable/catalog/insertable_table.py +5 -55
  7. pixeltable/catalog/named_function.py +2 -2
  8. pixeltable/catalog/schema_object.py +2 -7
  9. pixeltable/catalog/table.py +298 -204
  10. pixeltable/catalog/table_version.py +104 -139
  11. pixeltable/catalog/table_version_path.py +22 -4
  12. pixeltable/catalog/view.py +20 -10
  13. pixeltable/dataframe.py +128 -25
  14. pixeltable/env.py +21 -14
  15. pixeltable/exec/exec_context.py +5 -0
  16. pixeltable/exec/exec_node.py +1 -0
  17. pixeltable/exec/in_memory_data_node.py +29 -24
  18. pixeltable/exec/sql_scan_node.py +1 -1
  19. pixeltable/exprs/column_ref.py +13 -8
  20. pixeltable/exprs/data_row.py +4 -0
  21. pixeltable/exprs/expr.py +16 -1
  22. pixeltable/exprs/function_call.py +4 -4
  23. pixeltable/exprs/row_builder.py +29 -20
  24. pixeltable/exprs/similarity_expr.py +4 -3
  25. pixeltable/ext/functions/yolox.py +2 -1
  26. pixeltable/func/__init__.py +1 -0
  27. pixeltable/func/aggregate_function.py +14 -12
  28. pixeltable/func/callable_function.py +8 -6
  29. pixeltable/func/expr_template_function.py +13 -19
  30. pixeltable/func/function.py +3 -6
  31. pixeltable/func/query_template_function.py +84 -0
  32. pixeltable/func/signature.py +68 -23
  33. pixeltable/func/udf.py +13 -10
  34. pixeltable/functions/__init__.py +6 -91
  35. pixeltable/functions/eval.py +26 -14
  36. pixeltable/functions/fireworks.py +25 -23
  37. pixeltable/functions/globals.py +62 -0
  38. pixeltable/functions/huggingface.py +20 -16
  39. pixeltable/functions/image.py +170 -1
  40. pixeltable/functions/openai.py +95 -128
  41. pixeltable/functions/string.py +10 -2
  42. pixeltable/functions/together.py +95 -84
  43. pixeltable/functions/util.py +16 -0
  44. pixeltable/functions/video.py +94 -16
  45. pixeltable/functions/whisper.py +78 -0
  46. pixeltable/globals.py +1 -1
  47. pixeltable/io/__init__.py +10 -0
  48. pixeltable/io/external_store.py +370 -0
  49. pixeltable/io/globals.py +50 -22
  50. pixeltable/{datatransfer → io}/label_studio.py +279 -166
  51. pixeltable/io/parquet.py +1 -1
  52. pixeltable/iterators/__init__.py +9 -0
  53. pixeltable/iterators/string.py +40 -0
  54. pixeltable/metadata/__init__.py +6 -8
  55. pixeltable/metadata/converters/convert_10.py +2 -4
  56. pixeltable/metadata/converters/convert_12.py +7 -2
  57. pixeltable/metadata/converters/convert_13.py +6 -8
  58. pixeltable/metadata/converters/convert_14.py +2 -4
  59. pixeltable/metadata/converters/convert_15.py +40 -25
  60. pixeltable/metadata/converters/convert_16.py +18 -0
  61. pixeltable/metadata/converters/util.py +11 -8
  62. pixeltable/metadata/schema.py +3 -6
  63. pixeltable/plan.py +8 -7
  64. pixeltable/store.py +1 -1
  65. pixeltable/tool/create_test_db_dump.py +145 -54
  66. pixeltable/tool/embed_udf.py +9 -0
  67. pixeltable/type_system.py +1 -2
  68. pixeltable/utils/code.py +34 -0
  69. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/METADATA +2 -2
  70. pixeltable-0.2.9.dist-info/RECORD +131 -0
  71. pixeltable/datatransfer/__init__.py +0 -1
  72. pixeltable/datatransfer/remote.py +0 -113
  73. pixeltable/functions/pil/image.py +0 -147
  74. pixeltable-0.2.7.dist-info/RECORD +0 -126
  75. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/LICENSE +0 -0
  76. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/WHEEL +0 -0
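
The headline structural change is the removal of the old `Remote` data-transfer abstraction (pixeltable/datatransfer/remote.py, deleted) in favor of the new external-store machinery in pixeltable/io/external_store.py, with Label Studio support moving to pixeltable/io/label_studio.py. Based on the `LabelStudioProject.create()` and `sync()` signatures visible in the diff below, a rough usage sketch might look like the following. This is only illustrative: the table name, column names, and XML config are made up, and in practice users would more likely go through a convenience wrapper in pixeltable.io (per io/globals.py in the file list), which this diff does not show.

    import pixeltable as pxt
    from pixeltable.io.label_studio import LabelStudioProject

    # Hypothetical table with an image column named 'image'.
    t = pxt.get_table('my_images')

    # Keyword names follow LabelStudioProject.create() as shown in the diff below;
    # the public entry point may differ.
    project = LabelStudioProject.create(
        t,
        label_config='<View><Image name="img" value="$image"/></View>',
        name=None,                    # defaults to 'ls_project_<n>' (see create() below)
        title=None,                   # defaults to the table name
        media_import_method='file',   # requires Label Studio on the same host
        col_mapping=None,             # None lets create() set up a default annotations column
    )

    # sync() now returns a SyncStatus instead of None (see the diff of sync() below).
    status = project.sync(t, export_data=True, import_data=True)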
pixeltable/{datatransfer → io}/label_studio.py CHANGED
@@ -8,15 +8,14 @@ from xml.etree import ElementTree
 
 import PIL.Image
 import label_studio_sdk
-import more_itertools
 from requests.exceptions import HTTPError
 
 import pixeltable as pxt
 import pixeltable.env as env
 import pixeltable.exceptions as excs
-from pixeltable import Table
-from pixeltable.datatransfer.remote import Remote
-from pixeltable.exprs import ColumnRef, DataRow
+from pixeltable import Table, Column
+from pixeltable.exprs import ColumnRef, DataRow, Expr
+from pixeltable.io.external_store import Project, SyncStatus
 from pixeltable.utils import coco
 
 _logger = logging.getLogger('pixeltable')
@@ -31,67 +30,28 @@ def _label_studio_client() -> label_studio_sdk.Client:
     return env.Env.get().get_client('label_studio')
 
 
-class LabelStudioProject(Remote):
+class LabelStudioProject(Project):
     """
-    A [`Remote`][pixeltable.datatransfer.Remote] that represents a Label Studio project, providing functionality
+    An [`ExternalStore`][pixeltable.io.ExternalStore] that represents a Label Studio project, providing functionality
     for synchronizing between a Pixeltable table and a Label Studio project.
-
-    The API key and URL for a valid Label Studio server must be specified in Pixeltable config. Either:
-
-    * Set the `LABEL_STUDIO_API_KEY` and `LABEL_STUDIO_URL` environment variables; or
-    * Specify `api_key` and `url` fields in the `label-studio` section of `$PIXELTABLE_HOME/config.yaml`.
     """
-    # TODO(aaron-siegel): Add link in docstring to a Label Studio howto
 
-    def __init__(self, project_id: int, media_import_method: Literal['post', 'file']):
+    def __init__(
+        self,
+        name: str,
+        project_id: int,
+        media_import_method: Literal['post', 'file', 'url'],
+        col_mapping: dict[Column, str],
+        stored_proxies: Optional[dict[Column, Column]] = None
+    ):
+        """
+        The constructor will NOT create a new Label Studio project; it is also used when loading
+        metadata for existing projects.
+        """
         self.project_id = project_id
         self.media_import_method = media_import_method
         self._project: Optional[label_studio_sdk.project.Project] = None
-
-    @classmethod
-    def create(cls, title: str, label_config: str, media_import_method: Literal['post', 'file'] = 'file', **kwargs: Any) -> 'LabelStudioProject':
-        """
-        Creates a new Label Studio project, using the Label Studio client configured in Pixeltable.
-
-        Args:
-            title: The title of the project.
-            label_config: The Label Studio project configuration, in XML format.
-            media_import_method: The method to use when importing media columns to Label Studio:
-                - `file`: Media will be sent to Label Studio as a file on the local filesystem. This method can be
-                  used if Pixeltable and Label Studio are running on the same host.
-                - `post`: Media will be sent to Label Studio via HTTP post. This should generally only be used for
-                  prototyping; due to restrictions in Label Studio, it can only be used with projects that have
-                  just one data field.
-            **kwargs: Additional keyword arguments for the new project; these will be passed to `start_project`
-                in the Label Studio SDK.
-        """
-        # TODO(aaron-siegel): Add media_import_method = 'url' as an option
-        # Check that the config is valid before creating the project
-        config = cls.__parse_project_config(label_config)
-        if media_import_method == 'post' and len(config.data_keys) > 1:
-            raise excs.Error('`media_import_method` cannot be `post` if there is more than one data key')
-
-        project = _label_studio_client().start_project(title=title, label_config=label_config, **kwargs)
-
-        if media_import_method == 'file':
-            # We need to set up a local storage connection to receive media files
-            os.environ['LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT'] = str(env.Env.get().home)
-            try:
-                project.connect_local_import_storage(local_store_path=str(env.Env.get().media_dir))
-            except HTTPError as exc:
-                if exc.errno == 400:
-                    response: dict = json.loads(exc.response.text)
-                    if 'validation_errors' in response and 'non_field_errors' in response['validation_errors'] \
-                            and 'LOCAL_FILES_SERVING_ENABLED' in response['validation_errors']['non_field_errors'][0]:
-                        raise excs.Error(
-                            '`media_import_method` is set to `file`, but your Label Studio server is not configured '
-                            'for local file storage.\nPlease set the `LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED` '
-                            'environment variable to `true` in the environment where your Label Studio server is running.'
-                        ) from exc
-                raise  # Handle any other exception type normally
-
-        project_id = project.get_params()['id']
-        return LabelStudioProject(project_id, media_import_method)
+        super().__init__(name, col_mapping, stored_proxies)
 
     @property
     def project(self) -> label_studio_sdk.project.Project:
@@ -134,17 +94,22 @@ class LabelStudioProject(Remote):
         """
         return {ANNOTATIONS_COLUMN: pxt.JsonType(nullable=True)}
 
-    def sync(self, t: Table, col_mapping: dict[str, str], export_data: bool, import_data: bool) -> None:
+    def sync(self, t: Table, export_data: bool, import_data: bool) -> SyncStatus:
         _logger.info(f'Syncing Label Studio project "{self.project_title}" with table `{t.get_name()}`'
                      f' (export: {export_data}, import: {import_data}).')
         # Collect all existing tasks into a dict with entries `rowid: task`
         tasks = {tuple(task['meta']['rowid']): task for task in self.__fetch_all_tasks()}
+        sync_status = SyncStatus.empty()
         if export_data:
-            self.__update_tasks(t, col_mapping, tasks)
+            export_sync_status = self.__update_tasks(t, tasks)
+            sync_status = sync_status.combine(export_sync_status)
         if import_data:
-            self.__update_table_from_tasks(t, col_mapping, tasks)
+            import_sync_status = self.__update_table_from_tasks(t, tasks)
+            sync_status = sync_status.combine(import_sync_status)
+        return sync_status
 
-    def __fetch_all_tasks(self) -> Iterator[dict]:
+    def __fetch_all_tasks(self) -> Iterator[dict[str, Any]]:
+        """Retrieves all tasks and task metadata in this Label Studio project."""
         page = 1
         unknown_task_count = 0
         while True:
@@ -163,21 +128,28 @@ class LabelStudioProject(Remote):
                 f'Skipped {unknown_task_count} unrecognized task(s) when syncing Label Studio project "{self.project_title}".'
             )
 
-    def __update_tasks(self, t: Table, col_mapping: dict[str, str], existing_tasks: dict[tuple, dict]) -> None:
-
-        t_col_types = t.column_types()
+    def __update_tasks(self, t: Table, existing_tasks: dict[tuple, dict]) -> SyncStatus:
+        """
+        Updates all tasks in this Label Studio project based on the Pixeltable data:
+        - Creates new tasks for rows that don't map to any existing task;
+        - Updates existing tasks for rows whose data has changed;
+        - Deletes any tasks whose rows no longer exist in the Pixeltable table.
+        """
         config = self.__project_config
 
         # Columns in `t` that map to Label Studio data keys
         t_data_cols = [
-            t_col_name for t_col_name, r_col_name in col_mapping.items()
-            if r_col_name in config.data_keys
+            t_col for t_col, ext_col_name in self.col_mapping.items()
+            if ext_col_name in config.data_keys
        ]
 
+        if len(t_data_cols) == 0:
+            return SyncStatus.empty()
+
         # Columns in `t` that map to `rectanglelabels` preannotations
         t_rl_cols = [
-            t_col_name for t_col_name, r_col_name in col_mapping.items()
-            if r_col_name in config.rectangle_labels
+            t_col for t_col, ext_col_name in self.col_mapping.items()
+            if ext_col_name in config.rectangle_labels
        ]
 
         # Destinations for `rectanglelabels` preannotations
@@ -189,27 +161,27 @@ class LabelStudioProject(Remote):
 
         if self.media_import_method == 'post':
             # Send media to Label Studio by HTTP post.
-            self.__update_tasks_by_post(t, col_mapping, existing_tasks, t_data_cols[0], t_rl_cols, rl_info)
-        elif self.media_import_method == 'file':
-            # Send media to Label Studio by local file transfer.
-            self.__update_tasks_by_files(t, col_mapping, existing_tasks, t_data_cols, t_rl_cols, rl_info)
+            assert len(t_data_cols) == 1  # This was verified when the project was set up
+            return self.__update_tasks_by_post(t, existing_tasks, t_data_cols[0], t_rl_cols, rl_info)
+        elif self.media_import_method == 'file' or self.media_import_method == 'url':
+            # Send media to Label Studio by file reference (local file or URL).
+            return self.__update_tasks_by_files(t, existing_tasks, t_data_cols, t_rl_cols, rl_info)
         else:
             assert False
 
     def __update_tasks_by_post(
         self,
         t: Table,
-        col_mapping: dict[str, str],
         existing_tasks: dict[tuple, dict],
-        media_col_name: str,
-        t_rl_cols: list[str],
+        media_col: Column,
+        t_rl_cols: list[Column],
         rl_info: list['_RectangleLabel']
-    ) -> None:
-        is_stored = t[media_col_name].col.is_stored
+    ) -> SyncStatus:
+        is_stored = media_col.is_stored
         # If it's a stored column, we can use `localpath`
-        localpath_col_opt = [t[media_col_name].localpath] if is_stored else []
+        localpath_col_opt = [t[media_col.name].localpath] if is_stored else []
         # Select the media column, rectanglelabels columns, and localpath (if appropriate)
-        rows = t.select(t[media_col_name], *[t[col] for col in t_rl_cols], *localpath_col_opt)
+        rows = t.select(t[media_col.name], *[t[col.name] for col in t_rl_cols], *localpath_col_opt)
         tasks_created = 0
         row_ids_in_pxt: set[tuple] = set()
 
@@ -222,25 +194,25 @@ class LabelStudioProject(Remote):
             if is_stored:
                 # There is an existing localpath; use it!
                 localpath_col_idx = rows._select_list_exprs[-1].slot_idx
-                file = Path(row.vals[localpath_col_idx])
+                file = Path(row[localpath_col_idx])
                 task_id: int = self.project.import_tasks(file)[0]
             else:
                 # No localpath; create a temp file and upload it
-                assert isinstance(row.vals[media_col_idx], PIL.Image.Image)
+                assert isinstance(row[media_col_idx], PIL.Image.Image)
                 file = env.Env.get().create_tmp_path(extension='.png')
-                row.vals[media_col_idx].save(file, format='png')
+                row[media_col_idx].save(file, format='png')
                 task_id: int = self.project.import_tasks(file)[0]
                 os.remove(file)
 
             # Update the task with `rowid` metadata
-            self.project.update_task(task_id, meta={'rowid': row.rowid, 'v_min': row.v_min})
+            self.project.update_task(task_id, meta={'rowid': row.rowid})
 
             # Convert coco annotations to predictions
-            coco_annotations = [row.vals[i] for i in rl_col_idxs]
+            coco_annotations = [row[i] for i in rl_col_idxs]
             _logger.debug('`coco_annotations`: %s', coco_annotations)
             predictions = [
                 self.__coco_to_predictions(
-                    coco_annotations[i], col_mapping[t_rl_cols[i]], rl_info[i], task_id=task_id
+                    coco_annotations[i], self.col_mapping[t_rl_cols[i]], rl_info[i], task_id=task_id
                 )
                 for i in range(len(coco_annotations))
             ]
@@ -250,55 +222,75 @@ class LabelStudioProject(Remote):
 
         print(f'Created {tasks_created} new task(s) in {self}.')
 
-        self.__delete_stale_tasks(existing_tasks, row_ids_in_pxt, tasks_created)
+        sync_status = SyncStatus(external_rows_created=tasks_created)
+
+        deletion_sync_status = self.__delete_stale_tasks(existing_tasks, row_ids_in_pxt, tasks_created)
+
+        return sync_status.combine(deletion_sync_status)
 
     def __update_tasks_by_files(
         self,
         t: Table,
-        col_mapping: dict[str, str],
         existing_tasks: dict[tuple, dict],
-        t_data_cols: list[str],
-        t_rl_cols: list[str],
+        t_data_cols: list[Column],
+        t_rl_cols: list[Column],
         rl_info: list['_RectangleLabel']
-    ) -> None:
-        r_data_cols = [col_mapping[col_name] for col_name in t_data_cols]
-        col_refs = {}
-        for col_name in t_data_cols:
-            if not t[col_name].col_type.is_media_type():
-                # Not a media column; query the data directly
-                col_refs[col_name] = t[col_name]
-            elif t[col_name].col.stored_proxy:
-                # Media column that has a stored proxy; use it. We have to give it a name,
-                # since it's an anonymous column
-                col_refs[f'{col_name}_proxy'] = ColumnRef(t[col_name].col.stored_proxy).localpath
+    ) -> SyncStatus:
+        ext_data_cols = [self.col_mapping[col] for col in t_data_cols]
+        expr_refs: dict[str, Expr] = {}  # kwargs for the select statement
+        for col in t_data_cols:
+            col_name = col.name
+            if self.media_import_method == 'url':
+                expr_refs[col_name] = t[col_name].fileurl
             else:
-                # Media column without a stored proxy; this means it's a stored computed column,
-                # and we can just use the localpath
-                col_refs[col_name] = t[col_name].localpath
-
-        df = t.select(*[t[col] for col in t_rl_cols], **col_refs)
-        rl_col_idxs: Optional[list[int]] = None  # We have to wait until we begin iterating to populate these
+                assert self.media_import_method == 'file'
+                if not col.col_type.is_media_type():
+                    # Not a media column; query the data directly
+                    expr_refs[col_name] = t[col_name]
+                elif col in self.stored_proxies:
+                    # Media column that has a stored proxy; use it. We have to give it a name,
+                    # since it's an anonymous column
+                    stored_proxy_col = self.stored_proxies[col]
+                    expr_refs[f'{col_name}_proxy'] = ColumnRef(stored_proxy_col).localpath
+                else:
+                    # Media column without a stored proxy; this means it's a stored computed column,
+                    # and we can just use the localpath
+                    expr_refs[col_name] = t[col_name].localpath
+
+        df = t.select(*[t[col] for col in t_rl_cols], **expr_refs)
+        # The following buffers will hold `DataRow` indices that correspond to each of the selected
+        # columns. `rl_col_idxs` holds the indices for the columns that map to RectangleLabels
+        # preannotations; `data_col_idxs` holds the indices for the columns that map to data fields.
+        # We have to wait until we begin iterating to populate them, so they're initially `None`.
+        rl_col_idxs: Optional[list[int]] = None
         data_col_idxs: Optional[list[int]] = None
 
         row_ids_in_pxt: set[tuple] = set()
         tasks_created = 0
         tasks_updated = 0
-        page = []
+        page: list[dict[str, Any]] = []  # buffer to hold tasks for paginated API calls
 
-        def create_task_info(row: DataRow) -> dict:
-            data_vals = [row.vals[idx] for idx in data_col_idxs]
-            coco_annotations = [row.vals[idx] for idx in rl_col_idxs]
-            # For media columns, we need to transform the paths into Label Studio's bespoke path format
+        # Function that turns a `DataRow` into a `dict` for creating or updating a task in the
+        # Label Studio SDK.
+        def create_task_info(row: DataRow) -> dict[str, Any]:
+            data_vals = [row[idx] for idx in data_col_idxs]
+            coco_annotations = [row[idx] for idx in rl_col_idxs]
             for i in range(len(t_data_cols)):
-                if t[t_data_cols[i]].col_type.is_media_type():
-                    data_vals[i] = self.__localpath_to_lspath(data_vals[i])
+                if t_data_cols[i].col_type.is_media_type():
+                    # Special handling for media columns
+                    assert isinstance(data_vals[i], str)
+                    if self.media_import_method == 'url':
+                        data_vals[i] = self.__validate_fileurl(t_data_cols[i], data_vals[i])
+                    else:
+                        assert self.media_import_method == 'file'
+                        data_vals[i] = self.__localpath_to_lspath(data_vals[i])
             predictions = [
-                self.__coco_to_predictions(coco_annotations[i], col_mapping[t_rl_cols[i]], rl_info[i])
+                self.__coco_to_predictions(coco_annotations[i], self.col_mapping[t_rl_cols[i]], rl_info[i])
                 for i in range(len(coco_annotations))
             ]
             return {
-                'data': dict(zip(r_data_cols, data_vals)),
-                'meta': {'rowid': row.rowid, 'v_min': row.v_min},
+                'data': dict(zip(ext_data_cols, data_vals)),
+                'meta': {'rowid': row.rowid},
                 'predictions': predictions
             }
 
@@ -307,19 +299,19 @@ class LabelStudioProject(Remote):
                 rl_col_idxs = [expr.slot_idx for expr in df._select_list_exprs[:len(t_rl_cols)]]
                 data_col_idxs = [expr.slot_idx for expr in df._select_list_exprs[len(t_rl_cols):]]
             row_ids_in_pxt.add(row.rowid)
+            task_info = create_task_info(row)
+            # TODO(aaron-siegel): Implement more efficient update logic (currently involves a full table scan)
             if row.rowid in existing_tasks:
                 # A task for this row already exists; see if it needs an update.
-                # Get the v_min record from task metadata. Default to 0 if no v_min record is found
-                old_v_min = int(existing_tasks[row.rowid]['meta'].get('v_min', 0))
-                print(f'{old_v_min} {row.v_min}')
-                if row.v_min > old_v_min:
-                    _logger.debug(f'Updating task for rowid {row.rowid} ({row.v_min} > {old_v_min}).')
-                    task_info = create_task_info(row)
+                existing_task = existing_tasks[row.rowid]
+                if task_info['data'] != existing_task['data'] or \
+                        task_info['predictions'] != existing_task['predictions']:
+                    _logger.debug(f'Updating task for rowid {row.rowid}.')
                     self.project.update_task(existing_tasks[row.rowid]['id'], **task_info)
                     tasks_updated += 1
             else:
                 # No task exists for this row; we need to create one.
-                page.append(create_task_info(row))
+                page.append(task_info)
                 tasks_created += 1
                 if len(page) == _PAGE_SIZE:
                     self.project.import_tasks(page)
@@ -330,55 +322,103 @@ class LabelStudioProject(Remote):
 
         print(f'Created {tasks_created} new task(s) and updated {tasks_updated} existing task(s) in {self}.')
 
-        self.__delete_stale_tasks(existing_tasks, row_ids_in_pxt, tasks_created)
+        sync_status = SyncStatus(external_rows_created=tasks_created, external_rows_updated=tasks_updated)
+
+        deletion_sync_status = self.__delete_stale_tasks(existing_tasks, row_ids_in_pxt, tasks_created)
+
+        return sync_status.combine(deletion_sync_status)
+
+    @classmethod
+    def __validate_fileurl(cls, col: Column, url: str) -> Optional[str]:
+        # Check that the URL is one that will be visible to Label Studio. If it isn't, log an info message
+        # to help users debug the issue.
+        if not (url.startswith('http://') or url.startswith('https://')):
+            _logger.info(
+                f'URL found in media column `{col.name}` will not render correctly in Label Studio, since '
+                f'it is not an HTTP URL: {url}'
+            )
+        return url
 
     @classmethod
-    def __localpath_to_lspath(self, localpath: str) -> str:
-        assert isinstance(localpath, str)
+    def __localpath_to_lspath(cls, localpath: str) -> str:
+        # Transform the local path into Label Studio's bespoke path format.
         relpath = Path(localpath).relative_to(env.Env.get().home)
         return f'/data/local-files/?d={str(relpath)}'
 
-    def __delete_stale_tasks(self, existing_tasks: dict[tuple, dict], row_ids_in_pxt: set[tuple], tasks_created: int):
-        tasks_to_delete = [
-            task['id'] for rowid, task in existing_tasks.items()
-            if rowid not in row_ids_in_pxt
-        ]
+    def __delete_stale_tasks(self, existing_tasks: dict[tuple, dict], row_ids_in_pxt: set[tuple], tasks_created: int) -> SyncStatus:
+        deleted_rowids = set(existing_tasks.keys()) - row_ids_in_pxt
         # Sanity check the math
-        assert len(tasks_to_delete) == len(existing_tasks) + tasks_created - len(row_ids_in_pxt)
+        assert len(deleted_rowids) == len(existing_tasks) + tasks_created - len(row_ids_in_pxt)
+        tasks_to_delete = [existing_tasks[rowid]['id'] for rowid in deleted_rowids]
 
         if len(tasks_to_delete) > 0:
             self.project.delete_tasks(tasks_to_delete)
             print(f'Deleted {len(tasks_to_delete)} tasks(s) in {self} that are no longer present in Pixeltable.')
 
-    def __update_table_from_tasks(self, t: Table, col_mapping: dict[str, str], tasks: dict[tuple, dict]) -> None:
-        # `col_mapping` is guaranteed to be a one-to-one dict whose values are a superset
-        # of `get_pull_columns`
-        assert ANNOTATIONS_COLUMN in col_mapping.values()
-        annotations_column = next(k for k, v in col_mapping.items() if v == ANNOTATIONS_COLUMN)
-        updates = [
-            {
-                '_rowid': task['meta']['rowid'],
-                # Replace [] by None to indicate no annotations. We do want to sync rows with no annotations,
-                # in order to properly handle the scenario where existing annotations have been deleted in
-                # Label Studio.
-                annotations_column: task[ANNOTATIONS_COLUMN] if len(task[ANNOTATIONS_COLUMN]) > 0 else None
-            }
+        # Remove them from the `existing_tasks` dict so that future updates are applied correctly
+        for rowid in deleted_rowids:
+            del existing_tasks[rowid]
+
+        return SyncStatus(external_rows_deleted=len(deleted_rowids))
+
+    def __update_table_from_tasks(self, t: Table, tasks: dict[tuple, dict]) -> SyncStatus:
+        if ANNOTATIONS_COLUMN not in self.col_mapping.values():
+            return SyncStatus.empty()
+
+        annotations = {
+            # Replace [] by None to indicate no annotations. We do want to sync rows with no annotations,
+            # in order to properly handle the scenario where existing annotations have been deleted in
+            # Label Studio.
+            tuple(task['meta']['rowid']): task[ANNOTATIONS_COLUMN] if len(task[ANNOTATIONS_COLUMN]) > 0 else None
             for task in tasks.values()
-        ]
+        }
+
+        local_annotations_col = next(k for k, v in self.col_mapping.items() if v == ANNOTATIONS_COLUMN)
+
+        # Prune the annotations down to just the ones that have actually changed.
+        rows = t.select(t[local_annotations_col.name])
+        for row in rows._exec():
+            assert len(row.vals) == 1
+            if row.rowid in annotations and annotations[row.rowid] == row[0]:
+                del annotations[row.rowid]
+
+        # Apply updates
+        updates = [{'_rowid': rowid, local_annotations_col.name: ann} for rowid, ann in annotations.items()]
         if len(updates) > 0:
             _logger.info(
-                f'Updating table `{t.get_name()}`, column `{annotations_column}` with {len(updates)} total annotations.'
+                f'Updating table `{t.get_name()}`, column `{local_annotations_col.name}` with {len(updates)} total annotations.'
            )
-            t.batch_update(updates)
-        annotations_count = sum(len(task[ANNOTATIONS_COLUMN]) for task in tasks.values())
-        print(f'Synced {annotations_count} annotation(s) from {len(updates)} existing task(s) in {self}.')
+            # batch_update currently doesn't propagate from views to base tables. As a workaround, we call
+            # batch_update on the actual ancestor table that holds the annotations column.
+            # TODO(aaron-siegel): Simplify this once propagation is properly implemented in batch_update
+            ancestor = t
+            while local_annotations_col not in ancestor._tbl_version.cols:
+                assert ancestor.base is not None
+                ancestor = ancestor.base
+            update_status = ancestor.batch_update(updates)
+            print(f'Updated annotation(s) from {len(updates)} task(s) in {self}.')
+            return SyncStatus(pxt_rows_updated=update_status.num_rows, num_excs=update_status.num_excs)
+        else:
+            return SyncStatus.empty()
 
-    def to_dict(self) -> dict[str, Any]:
-        return {'project_id': self.project_id, 'media_import_method': self.media_import_method}
+    def as_dict(self) -> dict[str, Any]:
+        return {
+            'name': self.name,
+            'project_id': self.project_id,
+            'media_import_method': self.media_import_method,
+            'col_mapping': [[self._column_as_dict(k), v] for k, v in self.col_mapping.items()],
+            'stored_proxies': [[self._column_as_dict(k), self._column_as_dict(v)] for k, v in self.stored_proxies.items()]
+        }
 
     @classmethod
     def from_dict(cls, md: dict[str, Any]) -> 'LabelStudioProject':
-        return LabelStudioProject(md['project_id'], md['media_import_method'])
+        return LabelStudioProject(
+            md['name'],
+            md['project_id'],
+            md['media_import_method'],
+            {cls._column_from_dict(entry[0]): entry[1] for entry in md['col_mapping']},
+            {cls._column_from_dict(entry[0]): cls._column_from_dict(entry[1]) for entry in md['stored_proxies']}
        )
 
     def __repr__(self) -> str:
         name = self.project.get_params()['title']
@@ -394,27 +434,32 @@ class LabelStudioProject(Remote):
         if root.tag.lower() != 'view':
             raise excs.Error('Root of Label Studio config must be a `View`')
         config = _LabelStudioConfig(
-            data_keys=dict(cls.__parse_data_keys_config(root)),
-            rectangle_labels=dict(cls.__parse_rectangle_labels_config(root))
+            data_keys=cls.__parse_data_keys_config(root),
+            rectangle_labels=cls.__parse_rectangle_labels_config(root)
        )
         config.validate()
         return config
 
     @classmethod
-    def __parse_data_keys_config(cls, root: ElementTree.Element) -> Iterator[tuple[str, '_DataKey']]:
+    def __parse_data_keys_config(cls, root: ElementTree.Element) -> dict[str, '_DataKey']:
+        """Parses the data keys from a Label Studio XML config."""
+        config: dict[str, '_DataKey'] = {}
         for element in root:
             if 'value' in element.attrib and element.attrib['value'][0] == '$':
-                remote_col_name = element.attrib['value'][1:]
-                data_key_name = element.attrib.get('name')
-                element_type = _LS_TAG_MAP.get(element.tag.lower())
-                if element_type is None:
+                external_col_name = element.attrib['value'][1:]
+                name = element.attrib.get('name')
+                column_type = _LS_TAG_MAP.get(element.tag.lower())
+                if column_type is None:
                     raise excs.Error(
-                        f'Unsupported Label Studio data type: `{element.tag}` (in data key `{remote_col_name}`)'
+                        f'Unsupported Label Studio data type: `{element.tag}` (in data key `{external_col_name}`)'
                    )
-                yield remote_col_name, _DataKey(data_key_name, element_type)
+                config[external_col_name] = _DataKey(name=name, column_type=column_type)
+        return config
 
     @classmethod
-    def __parse_rectangle_labels_config(cls, root: ElementTree.Element) -> Iterator[tuple[str, '_RectangleLabel']]:
+    def __parse_rectangle_labels_config(cls, root: ElementTree.Element) -> dict[str, '_RectangleLabel']:
+        """Parses the RectangleLabels from a Label Studio XML config."""
+        config: dict[str, '_RectangleLabel'] = {}
         for element in root:
             if element.tag.lower() == 'rectanglelabels':
                 name = element.attrib['name']
@@ -426,7 +471,8 @@ class LabelStudioProject(Remote):
                 for label in labels:
                     if label not in coco.COCO_2017_CATEGORIES.values():
                         raise excs.Error(f'Label in `rectanglelabels` config is not a valid COCO object name: {label}')
-                yield name, _RectangleLabel(to_name=to_name, labels=labels)
+                config[name] = _RectangleLabel(to_name=to_name, labels=labels)
+        return config
 
     @classmethod
     def __coco_to_predictions(
@@ -481,6 +527,73 @@ class LabelStudioProject(Remote):
     def __hash__(self) -> int:
         return hash(self.project_id)
 
+    @classmethod
+    def create(
+        cls,
+        t: Table,
+        label_config: str,
+        name: Optional[str],
+        title: Optional[str],
+        media_import_method: Literal['post', 'file', 'url'],
+        col_mapping: Optional[dict[str, str]],
+        **kwargs: Any
+    ) -> 'LabelStudioProject':
+        """
+        Creates a new Label Studio project, using the Label Studio client configured in Pixeltable.
+        """
+        # Check that the config is valid before creating the project
+        config = cls.__parse_project_config(label_config)
+
+        if name is None:
+            # Create a default name that's unique to the table
+            all_stores = t.external_stores
+            n = 0
+            while f'ls_project_{n}' in all_stores:
+                n += 1
+            name = f'ls_project_{n}'
+
+        if title is None:
+            # `title` defaults to table name
+            title = t.get_name()
+
+        # Create a column to hold the annotations, if one does not yet exist
+        if col_mapping is None or ANNOTATIONS_COLUMN in col_mapping.values():
+            if col_mapping is None:
+                local_annotations_column = ANNOTATIONS_COLUMN
+            else:
+                local_annotations_column = next(k for k, v in col_mapping.items() if v == ANNOTATIONS_COLUMN)
+            if local_annotations_column not in t.column_names():
+                t[local_annotations_column] = pxt.JsonType(nullable=True)
+
+        resolved_col_mapping = cls.validate_columns(
+            t, config.export_columns, {ANNOTATIONS_COLUMN: pxt.JsonType(nullable=True)}, col_mapping)
+
+        # Perform some additional validation
+        if media_import_method == 'post' and len(config.data_keys) > 1:
+            raise excs.Error('`media_import_method` cannot be `post` if there is more than one data key')
+
+        project = _label_studio_client().start_project(title=title, label_config=label_config, **kwargs)
+
+        if media_import_method == 'file':
+            # We need to set up a local storage connection to receive media files
+            os.environ['LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT'] = str(env.Env.get().home)
+            try:
+                project.connect_local_import_storage(local_store_path=str(env.Env.get().media_dir))
+            except HTTPError as exc:
+                if exc.errno == 400:
+                    response: dict = json.loads(exc.response.text)
+                    if 'validation_errors' in response and 'non_field_errors' in response['validation_errors'] \
+                            and 'LOCAL_FILES_SERVING_ENABLED' in response['validation_errors']['non_field_errors'][0]:
+                        raise excs.Error(
+                            '`media_import_method` is set to `file`, but your Label Studio server is not configured '
+                            'for local file storage.\nPlease set the `LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED` '
+                            'environment variable to `true` in the environment where your Label Studio server is running.'
+                        ) from exc
+                raise  # Handle any other exception type normally
+
+        project_id = project.get_params()['id']
+        return LabelStudioProject(name, project_id, media_import_method, resolved_col_mapping)
+
 
 @dataclass(frozen=True)
 class _DataKey:
@@ -500,7 +613,7 @@ class _LabelStudioConfig:
     rectangle_labels: dict[str, _RectangleLabel]
 
     def validate(self) -> None:
-        data_key_names = set(key.name for key in self.data_keys.values() if key is not None)
+        data_key_names = set(key.name for key in self.data_keys.values() if key.name is not None)
         for name, rl in self.rectangle_labels.items():
             if rl.to_name not in data_key_names:
                 raise excs.Error(
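
For reference, the `SyncStatus` values threaded through `sync()` above are built up phase by phase and combined. The class itself lives in pixeltable/io/external_store.py, which is not shown in this diff, so the following is only a minimal reconstruction from the constructor calls visible above (`external_rows_created`, `external_rows_updated`, `external_rows_deleted`, `pxt_rows_updated`, `num_excs`), with combine assumed to sum the counters; the real class may differ.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class SyncStatus:
        # Field names are taken from the constructor calls in the diff above;
        # the actual class may carry additional fields.
        external_rows_created: int = 0
        external_rows_updated: int = 0
        external_rows_deleted: int = 0
        pxt_rows_updated: int = 0
        num_excs: int = 0

        @classmethod
        def empty(cls) -> 'SyncStatus':
            return cls()

        def combine(self, other: 'SyncStatus') -> 'SyncStatus':
            # Assumed semantics: counts from successive sync phases are summed.
            return SyncStatus(
                external_rows_created=self.external_rows_created + other.external_rows_created,
                external_rows_updated=self.external_rows_updated + other.external_rows_updated,
                external_rows_deleted=self.external_rows_deleted + other.external_rows_deleted,
                pxt_rows_updated=self.pxt_rows_updated + other.pxt_rows_updated,
                num_excs=self.num_excs + other.num_excs,
            )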
pixeltable/io/parquet.py CHANGED
@@ -63,7 +63,7 @@ def save_parquet(df: pxt.DataFrame, dest_path: Path, partition_size_bytes: int =
     # store the changes atomically
     with transactional_directory(dest_path) as temp_path:
         # dump metadata json file so we can inspect what was the source of the parquet file later on.
-        json.dump(df._as_dict(), (temp_path / '.pixeltable.json').open('w'))  # pylint: disable=protected-access
+        json.dump(df.as_dict(), (temp_path / '.pixeltable.json').open('w'))  # pylint: disable=protected-access
         json.dump(type_dict, (temp_path / '.pixeltable.column_types.json').open('w'))  # keep type metadata
 
         batch_num = 0