labelr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- labelr/apps/datasets.py +196 -14
- labelr/apps/directus.py +212 -0
- labelr/apps/google_batch.py +46 -1
- labelr/apps/label_studio.py +261 -64
- labelr/apps/typer_description.py +2 -0
- labelr/check.py +68 -7
- labelr/config.py +57 -1
- labelr/export/__init__.py +0 -0
- labelr/export/classification.py +114 -0
- labelr/export/common.py +42 -0
- labelr/export/llm.py +91 -0
- labelr/{export.py → export/object_detection.py} +97 -217
- labelr/google_genai.py +9 -3
- labelr/main.py +16 -0
- labelr/sample/__init__.py +0 -0
- labelr/sample/classification.py +17 -0
- labelr/sample/common.py +14 -0
- labelr/sample/llm.py +75 -0
- labelr/{sample.py → sample/object_detection.py} +38 -68
- labelr/utils.py +55 -5
- labelr-0.11.0.dist-info/METADATA +230 -0
- labelr-0.11.0.dist-info/RECORD +38 -0
- {labelr-0.9.0.dist-info → labelr-0.11.0.dist-info}/WHEEL +1 -1
- labelr-0.9.0.dist-info/METADATA +0 -159
- labelr-0.9.0.dist-info/RECORD +0 -28
- {labelr-0.9.0.dist-info → labelr-0.11.0.dist-info}/entry_points.txt +0 -0
- {labelr-0.9.0.dist-info → labelr-0.11.0.dist-info}/licenses/LICENSE +0 -0
- {labelr-0.9.0.dist-info → labelr-0.11.0.dist-info}/top_level.txt +0 -0
labelr/apps/datasets.py CHANGED

@@ -12,9 +12,14 @@ import typer
 from openfoodfacts import Flavor
 from openfoodfacts.utils import get_logger
 
-from labelr.export import export_from_ultralytics_to_hf
-
-
+from labelr.export.common import export_from_ultralytics_to_hf
+from labelr.export.object_detection import (
+    export_from_ls_to_hf_object_detection,
+    export_from_ls_to_ultralytics_object_detection,
+)
+
+from . import typer_description
+from ..config import config
 from ..types import ExportDestination, ExportSource, TaskType
 
 app = typer.Typer()
@@ -99,7 +104,9 @@ def convert_object_detection_dataset(
     Studio format, and save it to a JSON file."""
     from datasets import load_dataset
 
-    from labelr.sample import
+    from labelr.sample.object_detection import (
+        format_object_detection_sample_from_hf_to_ls,
+    )
 
     logger.info("Loading dataset: %s", repo_id)
     ds = load_dataset(repo_id)
@@ -119,7 +126,9 @@ def convert_object_detection_dataset(
 def export(
     from_: Annotated[ExportSource, typer.Option("--from", help="Input source to use")],
     to: Annotated[ExportDestination, typer.Option(help="Where to export the data")],
-    api_key: Annotated[
+    api_key: Annotated[
+        str | None, typer.Option(help=typer_description.LABEL_STUDIO_API_KEY)
+    ] = config.label_studio_api_key,
     task_type: Annotated[
         TaskType, typer.Option(help="Type of task to export")
     ] = TaskType.object_detection,
@@ -136,7 +145,16 @@ def export(
     project_id: Annotated[
         Optional[int], typer.Option(help="Label Studio Project ID")
     ] = None,
-
+    view_id: Annotated[
+        int | None,
+        typer.Option(
+            help="ID of the Label Studio view, if any. This option is useful "
+            "to filter the task to export."
+        ),
+    ] = None,
+    label_studio_url: Annotated[
+        str, typer.Option(help=typer_description.LABEL_STUDIO_URL)
+    ] = config.label_studio_url,
     output_dir: Annotated[
         Optional[Path],
         typer.Option(
@@ -157,11 +175,15 @@ def export(
     is_openfoodfacts_dataset: Annotated[
         bool,
         typer.Option(
-            help="Whether the Ultralytics dataset is an
-            "for Ultralytics source. This is used
-            "each image name
+            help="Whether the Ultralytics dataset is an Open Food Facts dataset, only "
+            "for Ultralytics source. This is used:\n"
+            "- to generate the correct image URLs from each image name, when exporting "
+            "from Ultralytics to Hugging Face Datasets.\n"
+            "- to include additional metadata fields specific to Open Food Facts "
+            "(`barcode` and `off_image_id`) when exporting from Label Studio to "
+            "Hugging Face Datasets."
         ),
-    ] =
+    ] = False,
     openfoodfacts_flavor: Annotated[
         Flavor,
         typer.Option(
@@ -175,9 +197,18 @@ def export(
         float,
         typer.Option(
             help="Train ratio for splitting the dataset, if the split name is not "
-            "provided
+            "provided. Only used if the source is Label Studio and the destination "
+            "is Ultralytics."
         ),
     ] = 0.8,
+    image_max_size: Annotated[
+        int | None,
+        typer.Option(
+            help="Maximum size (in pixels) for the images. If None, no resizing is performed."
+            "Otherwise, the longest side of the image will be resized to this value, "
+            "keeping the aspect ratio."
+        ),
+    ] = None,
     error_raise: Annotated[
         bool,
         typer.Option(
@@ -207,10 +238,8 @@ def export(
     local files (ultralytics format)."""
     from label_studio_sdk.client import LabelStudio
 
-    from labelr.export import (
+    from labelr.export.object_detection import (
         export_from_hf_to_ultralytics_object_detection,
-        export_from_ls_to_hf_object_detection,
-        export_from_ls_to_ultralytics_object_detection,
     )
 
     if (to == ExportDestination.hf or from_ == ExportSource.hf) and repo_id is None:
@@ -256,9 +285,12 @@ def export(
             repo_id=repo_id,
             label_names=typing.cast(list[str], label_names_list),
             project_id=typing.cast(int, project_id),
+            is_openfoodfacts_dataset=is_openfoodfacts_dataset,
             merge_labels=merge_labels,
             use_aws_cache=use_aws_cache,
             revision=revision,
+            view_id=view_id,
+            image_max_size=image_max_size,
         )
     elif to == ExportDestination.ultralytics:
         export_from_ls_to_ultralytics_object_detection(
@@ -270,6 +302,8 @@ def export(
             error_raise=error_raise,
             merge_labels=merge_labels,
             use_aws_cache=use_aws_cache,
+            view_id=view_id,
+            image_max_size=image_max_size,
         )
 
     elif from_ == ExportSource.hf:
@@ -285,6 +319,7 @@ def export(
             error_raise=error_raise,
             use_aws_cache=use_aws_cache,
             revision=revision,
+            image_max_size=image_max_size,
         )
     else:
         raise typer.BadParameter("Unsupported export format")
@@ -303,3 +338,150 @@ def export(
             is_openfoodfacts_dataset=is_openfoodfacts_dataset,
             openfoodfacts_flavor=openfoodfacts_flavor,
         )
+
+
+@app.command()
+def export_llm_ds(
+    dataset_path: Annotated[
+        Path, typer.Option(..., help="Path to the JSONL dataset file")
+    ],
+    repo_id: Annotated[
+        str, typer.Option(..., help="Hugging Face Datasets repository ID to export to")
+    ],
+    split: Annotated[str, typer.Option(..., help="Dataset split to export")],
+    revision: Annotated[
+        str,
+        typer.Option(
+            help="Revision (branch, tag or commit) for the Hugging Face Datasets repository."
+        ),
+    ] = "main",
+    tmp_dir: Annotated[
+        Path | None,
+        typer.Option(
+            help="Path to the temporary directory used to store intermediate sample files "
+            "created when building the HF dataset.",
+        ),
+    ] = None,
+    image_max_size: Annotated[
+        int | None,
+        typer.Option(
+            help="Maximum size (in pixels) for the images. If None, no resizing is performed.",
+        ),
+    ] = None,
+):
+    """Export LLM image extraction dataset with images only to Hugging Face
+    Datasets.
+    """
+    from labelr.export.llm import export_to_hf_llm_image_extraction
+    from labelr.sample.llm import load_llm_image_extraction_dataset_from_jsonl
+
+    sample_iter = load_llm_image_extraction_dataset_from_jsonl(
+        dataset_path=dataset_path
+    )
+    export_to_hf_llm_image_extraction(
+        sample_iter,
+        split=split,
+        repo_id=repo_id,
+        revision=revision,
+        tmp_dir=tmp_dir,
+        image_max_size=image_max_size,
+    )
+
+
+@app.command()
+def update_llm_ds(
+    dataset_path: Annotated[
+        Path, typer.Option(help="Path to the JSONL containing the updates.")
+    ],
+    repo_id: Annotated[
+        str, typer.Option(help="Hugging Face Datasets repository ID to update")
+    ],
+    split: Annotated[str, typer.Option(help="Dataset split to use")],
+    revision: Annotated[
+        str,
+        typer.Option(
+            help="Revision (branch, tag or commit) to use when pushing the new version "
+            "of the Hugging Face Dataset."
+        ),
+    ] = "main",
+    tmp_dir: Annotated[
+        Path | None,
+        typer.Option(
+            help="Path to a temporary directory to use for image processing",
+        ),
+    ] = None,
+    show_diff: Annotated[
+        bool,
+        typer.Option(
+            help="Show the differences between the original sample and the update. If "
+            "True, the updated dataset is not pushed to the Hub. Useful to review the "
+            "updates before applying them.",
+        ),
+    ] = False,
+):
+    """Update an existing LLM image extraction dataset, by updating the
+    `output` field of each sample in the dataset.
+
+    The `--dataset_path` JSONL file should contain items with two fields:
+
+    - `image_id`: The image ID of the sample to update in the Hugging Face
+      dataset.
+    - `output`: The new output data to set for the sample.
+    """
+    import sys
+    from difflib import Differ
+
+    import orjson
+    from datasets import load_dataset
+    from diskcache import Cache
+
+    dataset = load_dataset(repo_id, split=split)
+
+    # Populate cache with the updates
+    cache = Cache(directory=tmp_dir or None)
+    with dataset_path.open("r") as f:
+        for line in map(orjson.loads, f):
+            if "image_id" not in line or "output" not in line:
+                raise ValueError(
+                    "Each item in the update JSONL file must contain `image_id` and `output` fields"
+                )
+            image_id = line["image_id"]
+            output = line["output"]
+
+            if not isinstance(output, str):
+                output = orjson.dumps(output).decode("utf-8")
+
+            cache[image_id] = output
+
+    def apply_updates(sample):
+        image_id = sample["image_id"]
+        if image_id in cache:
+            cached_item = cache[image_id]
+            sample["output"] = cached_item
+        return sample
+
+    if show_diff:
+        differ = Differ()
+        for sample in dataset:
+            image_id = sample["image_id"]
+            if image_id in cache:
+                cached_item = orjson.loads(cache[image_id])
+                original_item = orjson.loads(sample["output"])
+                cached_item_str = orjson.dumps(
+                    cached_item, option=orjson.OPT_INDENT_2
+                ).decode("utf8")
+                original_item_str = orjson.dumps(
+                    original_item, option=orjson.OPT_INDENT_2
+                ).decode("utf8")
+                diff = list(
+                    differ.compare(
+                        original_item_str.splitlines(keepends=True),
+                        cached_item_str.splitlines(keepends=True),
+                    )
+                )
+                sys.stdout.writelines(diff)
+                sys.stdout.write("\n" + "-" * 30 + "\n")
+
+    else:
+        updated_dataset = dataset.map(apply_updates, batched=False)
+        updated_dataset.push_to_hub(repo_id, split=split, revision=revision)
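The new `update_llm_ds` command reads its updates from a JSONL file whose items carry `image_id` and `output` fields. A minimal sketch of producing such a file is shown below; the file name and field values are invented for illustration, and the assumption is only that each line is a JSON object with those two keys (`output` may be an object or an already-serialized string, since the command serializes non-string values itself).

import orjson

# Hypothetical updates; only the `image_id` and `output` keys matter.
updates = [
    {"image_id": "1234567890123_1", "output": {"brand": "Acme", "weight": "500 g"}},
    {"image_id": "1234567890123_2", "output": "{\"brand\": \"Acme\"}"},
]

with open("updates.jsonl", "wb") as f:
    for item in updates:
        f.write(orjson.dumps(item) + b"\n")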
labelr/apps/directus.py ADDED

@@ -0,0 +1,212 @@
+from pathlib import Path
+from typing import Annotated
+
+import requests
+import typer
+
+app = typer.Typer()
+
+
+DEFAULT_DIRECTUS_URL = "http://localhost:8055"
+
+
+def _list_endpoint_iter(
+    url: str,
+    session: requests.Session,
+    page_size: int,
+    method: str = "GET",
+    list_field: str | None = "data",
+    **kwargs,
+):
+    """Iterate over paginated Directus endpoint.
+
+    Args:
+        url (str): URL of the Directus endpoint.
+        session (requests.Session): Requests session to use for making HTTP
+            requests.
+        page_size (int): Number of items to fetch per page.
+        method (str, optional): HTTP method to use. Defaults to "GET".
+        list_field (str | None, optional): Field in the response JSON that
+            contains the list of items. If None, the entire response is used as
+            the list. Defaults to "data".
+        **kwargs: Additional keyword arguments to pass to the requests method.
+    Yields:
+        dict: Items from the Directus endpoint.
+    """
+    page = 0
+    next_page = True
+    params = kwargs.pop("params", {})
+
+    while next_page:
+        params["offset"] = page * page_size
+        params["limit"] = page_size
+        r = session.request(method=method, url=url, params=params, **kwargs)
+        r.raise_for_status()
+        response = r.json()
+        items = response[list_field] if list_field else response
+        if len(items) > 0:
+            yield from items
+        else:
+            next_page = False
+        page += 1
+
+
+def iter_items(
+    collection_name: str,
+    url: str,
+    session: requests.Session,
+    page_size: int = 50,
+    **kwargs,
+):
+    """Iterate over items in a Directus collection.
+
+    Args:
+        collection_name (str): Name of the Directus collection.
+        url (str): Base URL of the Directus server.
+        session (requests.Session): Requests session to use for making HTTP
+            requests.
+        page_size (int, optional): Number of items to fetch per page. Defaults
+            to 50.
+        **kwargs: Additional keyword arguments to pass to the requests method.
+    Yields:
+        dict: Items from the Directus collection.
+    """
+    yield from _list_endpoint_iter(
+        url=f"{url}/items/{collection_name}",
+        session=session,
+        page_size=page_size,
+        **kwargs,
+    )
+
+
+@app.command()
+def upload_data(
+    dataset_path: Annotated[
+        Path,
+        typer.Option(
+            help="Path to the dataset JSONL file to upload from.",
+            file_okay=True,
+            dir_okay=False,
+            readable=True,
+        ),
+    ],
+    collection: Annotated[
+        str, typer.Option(help="Name of the collection to upload the items to.")
+    ],
+    directus_url: Annotated[
+        str,
+        typer.Option(
+            help="Base URL of the Directus server.",
+        ),
+    ] = DEFAULT_DIRECTUS_URL,
+):
+    """Upload data to a Directus collection."""
+    import orjson
+    import requests
+    import tqdm
+
+    session = requests.Session()
+
+    with dataset_path.open("r") as f:
+        for item in tqdm.tqdm(map(orjson.loads, f), desc="items"):
+            r = session.post(
+                f"{directus_url}/items/{collection}",
+                json=item,
+            )
+            print(r.json())
+            r.raise_for_status()
+
+
+@app.command()
+def update_items(
+    collection: Annotated[
+        str, typer.Option(help="Name of the collection to upload the items to.")
+    ],
+    directus_url: Annotated[
+        str,
+        typer.Option(
+            help="Base URL of the Directus server.",
+        ),
+    ] = DEFAULT_DIRECTUS_URL,
+    sort: Annotated[
+        str | None,
+        typer.Option(help="The field to sort items by, defaults to None (no sorting)."),
+    ] = None,
+    skip: Annotated[
+        int, typer.Option(help="Number of items to skip, defaults to 0.")
+    ] = 0,
+):
+    """Update items in a Directus collection.
+
+    **Warning**: This command requires you to implement the processing
+    function inside the command. It is provided as a template for batch
+    updating items in a Directus collection.
+    """
+    import requests
+    import tqdm
+
+    session = requests.Session()
+
+    params = {} if sort is None else {"sort[]": sort}
+    for i, item in tqdm.tqdm(
+        enumerate(
+            iter_items(
+                collection_name=collection,
+                url=directus_url,
+                session=session,
+                params=params,
+            )
+        )
+    ):
+        if i < skip:
+            typer.echo(f"Skipping item {i}")
+            continue
+
+        item_id = item["id"]
+        # Implement your processing function here
+        # It should return a dict with the fields to update only
+        # If no update is needed, it should return None
+        patch_item = None
+
+        if patch_item is not None:
+            r = session.patch(
+                f"{directus_url}/items/{collection}/{item_id}",
+                json=patch_item,
+            )
+            r.raise_for_status()
+
+
+@app.command()
+def export_data(
+    output_path: Annotated[
+        Path, typer.Option(help="Path to the file to export to.", allow_dash=True)
+    ],
+    collection: Annotated[
+        str, typer.Option(help="Name of the collection to upload the items to.")
+    ],
+    directus_url: Annotated[
+        str,
+        typer.Option(
+            help="Base URL of the Directus server.",
+        ),
+    ] = DEFAULT_DIRECTUS_URL,
+):
+    """Export a directus collection to a JSONL file."""
+    import sys
+
+    import orjson
+    import requests
+    import tqdm
+
+    session = requests.Session()
+
+    f = sys.stdout if output_path.as_posix() == "-" else output_path.open("w")
+    with f:
+        for item in tqdm.tqdm(
+            iter_items(
+                collection_name=collection,
+                url=directus_url,
+                session=session,
+            )
+        ):
+            f.write(orjson.dumps(item).decode("utf-8") + "\n")
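The commands in the new `directus.py` module are built on the `iter_items` helper, which pages through `GET {url}/items/{collection}` with `offset`/`limit` query parameters until an empty page is returned. A rough standalone usage sketch follows; the "products" collection name is a placeholder, and the URL is simply the module's default.

import requests

from labelr.apps.directus import iter_items

session = requests.Session()
# Iterate over every item of a hypothetical "products" collection,
# fetching 100 items per request.
for item in iter_items(
    collection_name="products",
    url="http://localhost:8055",
    session=session,
    page_size=100,
):
    print(item["id"])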
labelr/apps/google_batch.py CHANGED

@@ -7,6 +7,7 @@ import typer
 from google.genai.types import JSONSchema as GoogleJSONSchema
 from google.genai.types import Schema as GoogleSchema
 from openfoodfacts import Flavor
+from openfoodfacts.types import JSONType
 from pydantic import BaseModel
 
 from labelr.google_genai import generate_batch_dataset, launch_batch_job
@@ -14,6 +15,40 @@ from labelr.google_genai import generate_batch_dataset, launch_batch_job
 app = typer.Typer()
 
 
+def _check_json_schema(item: JSONType) -> None:
+    if item.get("type") == "object":
+        required_fields = item.get("required", [])
+        all_fields = item.get("properties", [])
+        diff = set(all_fields) - set(required_fields)
+        if diff:
+            raise ValueError(
+                f"fields '{diff}' must be marked as required in the JSONSchema. "
+                "All fields with type 'object' must be required."
+            )
+    return None
+
+
+def check_json_schema(json_schema: JSONType) -> None:
+    """Check that for all `object`s, all fields are marked as required.
+
+    This is important to check, as otherwise the structured generation
+    backend may prevent the model to generate these fields.
+    This is the case as of vLLM 0.13 and xgrammars as backend.
+
+    To prevent this, we ask all fields to be marked as required.
+    """
+    stack = [json_schema]
+
+    for def_item in json_schema.get("$defs", {}).values():
+        stack.append(def_item)
+
+    while stack:
+        item = stack.pop()
+        _check_json_schema(item)
+        for sub_item in item.get("properties", {}).values():
+            stack.append(sub_item)
+
+
 def convert_pydantic_model_to_google_schema(schema: type[BaseModel]) -> dict[str, Any]:
     """Google doesn't support natively OpenAPI schemas, so we convert them to
     Google `Schema` (a subset of OpenAPI)."""
@@ -239,6 +274,12 @@ def upload_training_dataset_from_predictions(
             help="Whether to raise an error on invalid samples instead of skipping them",
         ),
     ] = False,
+    image_max_size: Annotated[
+        int | None,
+        typer.Option(
+            help="Maximum size (in pixels) for the images. If None, no resizing is performed.",
+        ),
+    ] = None,
 ):
     """Upload a training dataset to a Hugging Face Datasets repository from a
     Gemini batch prediction file."""
@@ -247,13 +288,16 @@ def upload_training_dataset_from_predictions(
     import orjson
     from huggingface_hub import HfApi
 
-    from labelr.export import export_to_hf_llm_image_extraction
+    from labelr.export.llm import export_to_hf_llm_image_extraction
     from labelr.google_genai import generate_sample_iter
 
     instructions = instructions_path.read_text()
     print(f"Instructions: {instructions}")
     json_schema = orjson.loads(json_schema_path.read_text())
 
+    # We check that all fields are marked as required
+    check_json_schema(json_schema)
+
     api = HfApi()
     config = {
         "instructions": instructions,
@@ -286,4 +330,5 @@ def upload_training_dataset_from_predictions(
         repo_id=repo_id,
         revision=revision,
         tmp_dir=tmp_dir,
+        image_max_size=image_max_size,
     )