labelr 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,34 +1,23 @@
1
1
  import functools
2
2
  import logging
3
3
  import pickle
4
- import random
5
4
  import tempfile
6
5
  from pathlib import Path
7
6
 
8
7
  import datasets
9
8
  import tqdm
10
9
  from label_studio_sdk.client import LabelStudio
11
- from openfoodfacts.images import download_image, generate_image_url
12
- from openfoodfacts.types import Flavor
13
- from PIL import Image, ImageOps
10
+ from openfoodfacts.images import download_image
14
11
 
15
- from labelr.sample import (
16
- HF_DS_CLASSIFICATION_FEATURES,
12
+ from labelr.export.common import _pickle_sample_generator
13
+ from labelr.sample.object_detection import (
17
14
  HF_DS_OBJECT_DETECTION_FEATURES,
18
15
  format_object_detection_sample_to_hf,
19
16
  )
20
- from labelr.types import TaskType
21
17
 
22
18
  logger = logging.getLogger(__name__)
23
19
 
24
20
 
25
def _pickle_sample_generator(dir: Path):
    """Yield the object stored in each ``*.pkl`` file found directly in `dir`.

    Files are visited in the order returned by `Path.glob`, which is not
    guaranteed to be sorted.

    NOTE: `pickle.load` executes arbitrary code; only point this at
    directories whose pickles were written by this package.
    """
    for pickle_path in dir.glob("*.pkl"):
        with pickle_path.open("rb") as stream:
            yield pickle.load(stream)
30
-
31
-
32
21
  def export_from_ls_to_hf_object_detection(
33
22
  ls: LabelStudio,
34
23
  repo_id: str,
@@ -331,127 +320,3 @@ def export_from_hf_to_ultralytics_object_detection(
331
320
  f.write("names:\n")
332
321
  for i, category_name in enumerate(category_names):
333
322
  f.write(f" {i}: {category_name}\n")
334
-
335
-
336
def export_from_ultralytics_to_hf(
    task_type: TaskType,
    dataset_dir: Path,
    repo_id: str,
    label_names: list[str],
    merge_labels: bool = False,
    is_openfoodfacts_dataset: bool = False,
    openfoodfacts_flavor: Flavor = Flavor.off,
) -> None:
    """Export an Ultralytics dataset to a Hugging Face dataset repository.

    Only classification datasets are supported; the actual work is delegated
    to `export_from_ultralytics_to_hf_classification`.

    Args:
        task_type (TaskType): Type of task the dataset was built for.
        dataset_dir (Path): Path to the Ultralytics dataset directory.
        repo_id (str): Hugging Face repository ID to push the dataset to.
        label_names (list[str]): List of label names.
        merge_labels (bool): Whether to merge all labels into a single label
            named 'object'.
        is_openfoodfacts_dataset (bool): Whether the dataset is from
            Open Food Facts.
        openfoodfacts_flavor (Flavor): Flavor of Open Food Facts dataset. This
            is ignored if `is_openfoodfacts_dataset` is False.

    Raises:
        NotImplementedError: If `task_type` is not `TaskType.classification`.
    """
    if task_type != TaskType.classification:
        raise NotImplementedError(
            "Only classification task is currently supported for Ultralytics to HF export"
        )

    # Past the guard above the task is necessarily a classification task, so
    # no second check is needed before delegating.
    export_from_ultralytics_to_hf_classification(
        dataset_dir=dataset_dir,
        repo_id=repo_id,
        label_names=label_names,
        merge_labels=merge_labels,
        is_openfoodfacts_dataset=is_openfoodfacts_dataset,
        openfoodfacts_flavor=openfoodfacts_flavor,
    )
359
-
360
-
361
def export_from_ultralytics_to_hf_classification(
    dataset_dir: Path,
    repo_id: str,
    label_names: list[str],
    merge_labels: bool = False,
    is_openfoodfacts_dataset: bool = False,
    openfoodfacts_flavor: Flavor = Flavor.off,
) -> None:
    """Export an Ultralytics classification dataset to a Hugging Face dataset.

    The Ultralytics dataset directory should contain 'train', 'val' and/or
    'test' subdirectories, each containing subdirectories for each label.
    Each existing split is converted and pushed to the Hub separately.

    Args:
        dataset_dir (Path): Path to the Ultralytics dataset directory.
        repo_id (str): Hugging Face repository ID to push the dataset to.
        label_names (list[str]): List of label names.
        merge_labels (bool): Whether to merge all labels into a single label
            named 'object'.
        is_openfoodfacts_dataset (bool): Whether the dataset is from
            Open Food Facts. If True, the `off_image_id` and `image_url` will
            be generated automatically. `off_image_id` is extracted from the
            image filename.
        openfoodfacts_flavor (Flavor): Flavor of Open Food Facts dataset. This
            is ignored if `is_openfoodfacts_dataset` is False.

    Raises:
        ValueError: If none of the split directories exists, or if a label
            (after optional merging) is not listed in `label_names`.
    """
    logger.info("Repo ID: %s, dataset_dir: %s", repo_id, dataset_dir)

    splits = ["train", "val", "test"]
    if not any((dataset_dir / split).is_dir() for split in splits):
        raise ValueError(
            f"Dataset directory {dataset_dir} does not contain 'train', 'val' or 'test' subdirectories"
        )

    for split in splits:
        split_dir = dataset_dir / split

        if not split_dir.is_dir():
            logger.info("Skipping missing split directory: %s", split_dir)
            continue

        # Samples are staged as pickles in a temporary directory, then
        # streamed into the Hugging Face dataset from there.
        with tempfile.TemporaryDirectory() as tmp_dir_str:
            tmp_dir = Path(tmp_dir_str)
            # Running counter used to name the pickles. Naming them after
            # `image_id` made two images sharing a stem (e.g. under two
            # different label directories) overwrite each other silently.
            sample_count = 0
            for label_dir in (d for d in split_dir.iterdir() if d.is_dir()):
                label_name = label_dir.name
                if merge_labels:
                    label_name = "object"
                if label_name not in label_names:
                    raise ValueError(
                        "Label name %s not in provided label names (label names: %s)"
                        % (label_name, label_names),
                    )
                label_id = label_names.index(label_name)

                for image_path in label_dir.glob("*"):
                    if is_openfoodfacts_dataset:
                        # File stems are expected to look like
                        # `<barcode>_<off_image_id>[_...]`.
                        image_stem_parts = image_path.stem.split("_")
                        barcode = image_stem_parts[0]
                        off_image_id = image_stem_parts[1]
                        image_id = f"{barcode}_{off_image_id}"
                        image_url = generate_image_url(
                            barcode, off_image_id, flavor=openfoodfacts_flavor
                        )
                    else:
                        image_id = image_path.stem
                        barcode = ""
                        off_image_id = ""
                        image_url = ""
                    image = Image.open(image_path)
                    image.load()

                    if image.mode != "RGB":
                        image = image.convert("RGB")

                    # Rotate image according to exif orientation using Pillow
                    ImageOps.exif_transpose(image, in_place=True)
                    sample = {
                        "image_id": image_id,
                        "image": image,
                        "width": image.width,
                        "height": image.height,
                        "meta": {
                            "barcode": barcode,
                            "off_image_id": off_image_id,
                            "image_url": image_url,
                        },
                        "category_id": label_id,
                        "category_name": label_name,
                    }
                    pickle_path = tmp_dir / f"{split}_{sample_count:09d}.pkl"
                    with open(pickle_path, "wb") as f:
                        pickle.dump(sample, f)
                    sample_count += 1

            hf_ds = datasets.Dataset.from_generator(
                functools.partial(_pickle_sample_generator, tmp_dir),
                features=HF_DS_CLASSIFICATION_FEATURES,
            )
            hf_ds.push_to_hub(repo_id, split=split)
labelr/google_genai.py ADDED
@@ -0,0 +1,421 @@
1
+ import asyncio
2
+ import mimetypes
3
+ from collections.abc import Iterator
4
+ from pathlib import Path
5
+ from typing import Literal
6
+ from urllib.parse import urlparse
7
+
8
+ import aiofiles
9
+ import jsonschema
10
+ import orjson
11
+ import typer
12
+ from gcloud.aio.storage import Storage
13
+ from openfoodfacts import Flavor
14
+ from openfoodfacts.images import generate_image_url
15
+ from tqdm.asyncio import tqdm
16
+
17
+ from labelr.sample.common import SampleMeta
18
+ from labelr.sample.llm import LLMImageExtractionSample
19
+ from labelr.utils import download_image_from_gcs
20
+
21
+ try:
22
+ import google.genai # noqa: F401
23
+ except ImportError:
24
+ raise ImportError(
25
+ "The 'google-genai' package is required to use this module. "
26
+ "Please install labelr with the 'google' extra: "
27
+ "`pip install labelr[google]`"
28
+ )
29
+ import aiohttp
30
+ from google import genai
31
+ from google.cloud import storage
32
+ from google.genai.types import CreateBatchJobConfig, HttpOptions
33
+ from google.genai.types import JSONSchema as GoogleJSONSchema
34
+ from google.genai.types import Schema as GoogleSchema
35
+ from openfoodfacts.types import JSONType
36
+ from pydantic import BaseModel
37
+
38
+
39
class RawBatchSamplePart(BaseModel):
    """One element of a raw batch sample: either a text chunk or an image."""

    # Kind of content carried by `data`.
    type: Literal["text", "image"]
    # The text itself for "text" parts; for "image" parts, the image
    # location (its name is used to guess the MIME type and to fetch it).
    data: str
42
+
43
+
44
class RawBatchSample(BaseModel):
    """A raw input sample, i.e. one line of the batch input JSONL file."""

    # Identifier of the sample; it is embedded in the request key and in the
    # GCS blob path of the sample's images.
    key: str
    # Ordered text/image parts that make up the prompt for this sample.
    parts: list[RawBatchSamplePart]
    # Free-form metadata attached to the sample (empty by default).
    meta: JSONType = {}
48
+
49
+
50
def convert_pydantic_model_to_google_schema(schema: type[BaseModel]) -> JSONType:
    """Convert a Pydantic model class into a Google `Schema` dictionary.

    Google doesn't natively support OpenAPI schemas, so the JSON schema of
    the model is converted to a Google `Schema` (a subset of OpenAPI) and
    dumped as a JSON-compatible dict without unset or `None` fields.

    Args:
        schema (type[BaseModel]): Pydantic model class to convert.
    Returns:
        JSONType: The Google-compatible schema.
    """
    model_json_schema = schema.model_json_schema()
    validated_schema = GoogleJSONSchema.model_validate(model_json_schema)
    google_schema = GoogleSchema.from_json_schema(json_schema=validated_schema)
    return google_schema.model_dump(
        mode="json", exclude_none=True, exclude_unset=True
    )
56
+
57
+
58
async def download_image(url: str, session: aiohttp.ClientSession) -> bytes:
    """Fetch an image over HTTP and return its raw bytes.

    Args:
        url (str): URL of the image to download.
        session (aiohttp.ClientSession): HTTP session used for the request.
    Returns:
        bytes: Content of the downloaded image.
    Raises:
        Whatever `response.raise_for_status()` raises when the server
        answers with an error status.
    """
    async with session.get(url) as http_response:
        http_response.raise_for_status()
        payload = await http_response.read()
        return payload
69
+
70
+
71
async def download_image_from_filesystem(url: str, base_dir: Path) -> bytes:
    """Read an image from the local filesystem and return its content as bytes.

    The path component of `url` is resolved relative to `base_dir`.

    Args:
        url (str): URL of the image; only its path component is used.
        base_dir (Path): Base directory where images are stored.
    Returns:
        bytes: Content of the image file.
    Raises:
        FileNotFoundError: If no file exists at the resolved location.
    """
    # Strip leading slashes so the path is always relative to `base_dir`.
    # Blindly slicing off the first character would corrupt a path that has
    # no leading '/', and a path starting with '//' would stay absolute and
    # make the join below discard `base_dir`.
    relative_path = urlparse(url).path.lstrip("/")
    full_file_path = base_dir / relative_path
    async with aiofiles.open(full_file_path, "rb") as f:
        return await f.read()
84
+
85
+
86
async def upload_to_gcs(
    image_url: str,
    bucket_name: str,
    blob_name: str,
    session: aiohttp.ClientSession,
    base_image_dir: Path | None = None,
) -> dict:
    """Fetch an image and upload it to Google Cloud Storage.

    Args:
        image_url (str): URL of the image to upload.
        bucket_name (str): Name of the GCS bucket.
        blob_name (str): Name of the blob (object) in the bucket.
        session (aiohttp.ClientSession): HTTP session used both to download
            the image and to talk to GCS.
        base_image_dir (Path | None): If provided, the image is read from
            the filesystem under this base directory instead of being
            downloaded from its URL.
    Returns:
        dict: Status of the upload operation.
    """
    if base_image_dir is not None:
        payload = await download_image_from_filesystem(image_url, base_image_dir)
    else:
        payload = await download_image(image_url, session)

    gcs_client = Storage(session=session)
    return await gcs_client.upload(bucket_name, blob_name, payload)
119
+
120
+
121
async def upload_to_gcs_format_async(
    sample: RawBatchSample,
    google_json_schema: JSONType,
    instructions: str | None,
    bucket_name: str,
    bucket_dir_name: str,
    session: aiohttp.ClientSession,
    base_image_dir: Path | None = None,
    skip_upload: bool = False,
    thinking_level: str | None = None,
) -> JSONType | None:
    """Build the batch request record of a sample, uploading its images to GCS.

    Args:
        sample (RawBatchSample): Raw sample to convert.
        google_json_schema (JSONType): Schema the model response must follow.
        instructions (str | None): Optional text placed before the sample
            parts in the request.
        bucket_name (str): Name of the GCS bucket receiving the images.
        bucket_dir_name (str): Directory (blob prefix) inside the bucket.
        session (aiohttp.ClientSession): HTTP session used for transfers.
        base_image_dir (Path | None): If provided, images are read from the
            filesystem under this directory instead of being downloaded.
        skip_upload (bool): If True, no image is uploaded; the request still
            references the GCS location the image would have.
        thinking_level (str | None): Optional thinking level added to the
            generation config.
    Returns:
        JSONType | None: The request record, or None when an image file
        could not be found.
    Raises:
        ValueError: If the MIME type of an image cannot be guessed.
    """
    request_parts: list[JSONType] = [{"text": instructions}] if instructions else []

    for item in sample.parts:
        if item.type != "image":
            request_parts.append({"text": item.data})
            continue

        source_uri = item.data
        guessed_mime, _encoding = mimetypes.guess_type(source_uri)
        if guessed_mime is None:
            raise ValueError(f"Cannot guess mimetype of file: {source_uri}")

        # Images of a sample are grouped under `<dir>/<sample key>/`.
        blob_path = f"{bucket_dir_name}/{sample.key}/{Path(source_uri).name}"
        if not skip_upload:
            try:
                await upload_to_gcs(
                    image_url=source_uri,
                    bucket_name=bucket_name,
                    blob_name=blob_path,
                    session=session,
                    base_image_dir=base_image_dir,
                )
            except FileNotFoundError:
                # A missing image invalidates the whole sample.
                return None

        request_parts.append(
            {
                "file_data": {
                    "file_uri": f"gs://{bucket_name}/{blob_path}",
                    "mime_type": guessed_mime,
                }
            }
        )

    generation_config = {
        "responseMimeType": "application/json",
        "response_json_schema": google_json_schema,
    }
    if thinking_level is not None:
        generation_config["thinkingConfig"] = {"thinkingLevel": thinking_level}

    user_turn = {"parts": request_parts, "role": "user"}
    return {
        "key": f"key:{sample.key}",
        "request": {
            "contents": [user_turn],
            "generationConfig": generation_config,
        },
    }
189
+
190
+
191
async def generate_batch_dataset(
    data_path: Path,
    output_path: Path,
    google_json_schema: JSONType,
    instructions: str | None,
    bucket_name: str,
    bucket_dir_name: str,
    max_concurrent_uploads: int = 30,
    base_image_dir: Path | None = None,
    from_key: str | None = None,
    skip_upload: bool = False,
    thinking_level: str | None = None,
):
    """Convert a raw JSONL dataset into a batch request file, uploading images.

    Every line of `data_path` is parsed as a `RawBatchSample`, its images are
    uploaded to GCS (unless `skip_upload` is set) and the resulting request
    record is appended to `output_path`. Samples for which an image file is
    missing are dropped and counted.

    Args:
        data_path (Path): Input JSONL file, one `RawBatchSample` per line.
        output_path (Path): Output JSONL file receiving the request records.
        google_json_schema (JSONType): Schema the model response must follow.
        instructions (str | None): Optional instructions prepended to every
            request.
        bucket_name (str): Name of the GCS bucket receiving the images.
        bucket_dir_name (str): Directory (blob prefix) inside the bucket.
        max_concurrent_uploads (int): Number of samples processed
            concurrently before their results are awaited and written.
        base_image_dir (Path | None): If provided, images are read from the
            filesystem under this directory instead of being downloaded.
        from_key (str | None): If provided, every sample before the one with
            this key is skipped; processing starts at that sample (included).
        skip_upload (bool): If True, images are not uploaded.
        thinking_level (str | None): Optional thinking level for the model.
    """
    # Samples are ignored until `from_key` is met. The flag used to be
    # initialised the other way round, which made `from_key` a no-op.
    ignore = from_key is not None
    missing_files = 0

    async def _drain(pending: list, output_file) -> int:
        """Await the pending tasks, write their records and empty the list.

        Returns the number of tasks that produced no record (missing file).
        """
        missing = 0
        for pending_task in pending:
            record = await pending_task
            if record is None:
                missing += 1
            else:
                await output_file.write(orjson.dumps(record) + b"\n")
        pending.clear()
        return missing

    async with aiohttp.ClientSession() as session:
        async with asyncio.TaskGroup() as tg:
            async with (
                aiofiles.open(data_path, "r") as input_file,
                aiofiles.open(output_path, "wb") as output_file,
            ):
                # Concurrency is bounded by draining the tasks every
                # `max_concurrent_uploads` samples. A list (rather than a
                # set) keeps the output in input order.
                tasks: list = []
                async for line in tqdm(input_file, desc="samples"):
                    sample = RawBatchSample.model_validate_json(line)
                    if ignore:
                        if sample.key != from_key:
                            continue
                        ignore = False

                    tasks.append(
                        tg.create_task(
                            upload_to_gcs_format_async(
                                sample=sample,
                                google_json_schema=google_json_schema,
                                instructions=instructions,
                                bucket_name=bucket_name,
                                bucket_dir_name=bucket_dir_name,
                                session=session,
                                base_image_dir=base_image_dir,
                                skip_upload=skip_upload,
                                thinking_level=thinking_level,
                            )
                        )
                    )

                    if len(tasks) >= max_concurrent_uploads:
                        missing_files += await _drain(tasks, output_file)

                # Flush the last, possibly incomplete, group of tasks.
                missing_files += await _drain(tasks, output_file)

    typer.echo(
        f"Upload and dataset update completed. Wrote updated dataset to {output_path}. "
        f"Missing files: {missing_files}."
    )
266
+
267
+
268
def launch_batch_job(
    run_name: str,
    dataset_path: Path,
    model: str,
    location: str,
    bucket_name: str = "robotoff-batch",
):
    """Launch a Gemini Batch Inference job.

    The dataset file is first uploaded to
    `gs://<bucket_name>/gemini-batch/<run_name>/inputs.jsonl`; the job output
    is written under `gs://<bucket_name>/gemini-batch/<run_name>`. The
    created job is printed to stdout.

    Args:
        run_name (str): Name of the batch run.
        dataset_path (Path): Path to the dataset file in JSONL format.
        model (str): Model to use for the batch job. Example:
            'gemini-2.5-flash'.
        location (str): Location for the Vertex AI resources. Example:
            'europe-west4'.
        bucket_name (str): Name of the GCS bucket holding the input file and
            the job output.
    """
    if model == "gemini-3-pro-preview" and location != "global":
        typer.echo(
            "Warning: only 'global' location is supported for 'gemini-3-pro-preview' model. Overriding location to 'global'."
        )
        location = "global"

    # Upload the dataset to the GCS bucket with the Google Cloud Storage
    # client, so that the batch job can read it.
    storage_client = storage.Client()
    run_dir = f"gemini-batch/{run_name}"
    input_file_blob_name = f"{run_dir}/inputs.jsonl"
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(input_file_blob_name)
    blob.upload_from_filename(dataset_path)

    client = genai.Client(
        http_options=HttpOptions(api_version="v1"),
        vertexai=True,
        location=location,
    )
    output_uri = f"gs://{bucket_name}/{run_dir}"
    job = client.batches.create(
        model=model,
        src=f"gs://{bucket_name}/{input_file_blob_name}",
        config=CreateBatchJobConfig(dest=output_uri),
    )
    print(job)
312
+
313
+
314
def generate_sample_iter(
    prediction_path: Path,
    json_schema: JSONType,
    skip: int = 0,
    limit: int | None = None,
    is_openfoodfacts_dataset: bool = False,
    openfoodfacts_flavor: Flavor = Flavor.off,
    raise_on_invalid_sample: bool = False,
) -> Iterator[LLMImageExtractionSample]:
    """Generate training samples from a Gemini Batch Inference prediction
    JSONL file.

    A summary of skipped and invalid samples is echoed once the file has
    been fully consumed.

    Args:
        prediction_path (Path): Path to the prediction JSONL file.
        json_schema (JSONType): JSON schema to validate the predictions.
        skip (int): Number of initial samples to skip.
        limit (int | None): Maximum number of lines to process after the
            skipped ones (invalid samples count towards the limit).
        is_openfoodfacts_dataset (bool): Whether the dataset is from Open Food
            Facts.
        openfoodfacts_flavor (Flavor): Flavor of the Open Food Facts dataset.
        raise_on_invalid_sample (bool): If True, re-raise the error met while
            converting a sample; otherwise report it and skip the sample.
    Yields:
        Iterator[LLMImageExtractionSample]: Generated samples.
    """
    skipped = 0
    invalid = 0
    # A single storage client is shared by all the samples.
    storage_client = storage.Client()
    with prediction_path.open("r") as f_in:
        for i, sample_str in enumerate(f_in):
            if i < skip:
                skipped += 1
                continue
            if limit is not None and i >= skip + limit:
                break
            sample = orjson.loads(sample_str)
            try:
                extraction_sample = generate_sample_from_prediction(
                    json_schema=json_schema,
                    sample=sample,
                    is_openfoodfacts_dataset=is_openfoodfacts_dataset,
                    openfoodfacts_flavor=openfoodfacts_flavor,
                    storage_client=storage_client,
                )
            except Exception as e:
                if raise_on_invalid_sample:
                    raise
                typer.echo(
                    f"Skipping invalid sample at line {i + 1} in {prediction_path}: {e}"
                )
                invalid += 1
                continue
            # Yield outside the `try` so that an exception thrown into the
            # generator by its consumer is not reported as an invalid sample.
            yield extraction_sample
    if skipped > 0:
        typer.echo(f"Skipped {skipped} samples.")
    if invalid > 0:
        typer.echo(f"Skipped {invalid} invalid samples.")
369
+
370
+
371
def generate_sample_from_prediction(
    json_schema: JSONType,
    sample: JSONType,
    is_openfoodfacts_dataset: bool = False,
    openfoodfacts_flavor: Flavor = Flavor.off,
    storage_client: storage.Client | None = None,
) -> LLMImageExtractionSample:
    """Generate a LLMImageExtractionSample from a prediction sample.

    Args:
        json_schema (JSONType): JSON schema to validate the predictions.
        sample (JSONType): Prediction sample.
        is_openfoodfacts_dataset (bool): Whether the dataset is from Open Food
            Facts.
        openfoodfacts_flavor (Flavor): Flavor of the Open Food Facts dataset.
        storage_client (storage.Client | None): Optional Google Cloud Storage
            client, forwarded to `download_image_from_gcs`.
    Returns:
        LLMImageExtractionSample: Generated sample.
    Raises:
        ValueError: If the request of the sample contains no image part.
        jsonschema.ValidationError: If the response does not follow
            `json_schema`.
    """
    # Request keys are written as "key:<sample key>".
    image_id = sample["key"].removeprefix("key:")
    response_str = sample["response"]["candidates"][0]["content"]["parts"][0]["text"]

    # Validate the response before downloading anything, so that an invalid
    # prediction does not cost a GCS download.
    response = orjson.loads(response_str)
    jsonschema.validate(response, json_schema)

    # The image is not always the second part: instructions are only
    # prepended to the request when they were provided. Use the first part
    # that references a file instead of a fixed index.
    request_parts = sample["request"]["contents"][0]["parts"]
    image_uri = next(
        (
            part["file_data"]["file_uri"]
            for part in request_parts
            if "file_data" in part
        ),
        None,
    )
    if image_uri is None:
        raise ValueError(f"No image part found in the request of sample {image_id}")
    image = download_image_from_gcs(image_uri=image_uri, client=storage_client)

    if is_openfoodfacts_dataset:
        # Keys are expected to look like `<barcode>_<off_image_id>[_...]`.
        image_stem_parts = image_id.split("_")
        barcode = image_stem_parts[0]
        off_image_id = image_stem_parts[1]
        image_id = f"{barcode}_{off_image_id}"
        image_url = generate_image_url(
            barcode, off_image_id, flavor=openfoodfacts_flavor
        )
    else:
        barcode = ""
        off_image_id = ""
        image_url = ""

    sample_meta = SampleMeta(
        barcode=barcode,
        off_image_id=off_image_id,
        image_url=image_url,
    )
    return LLMImageExtractionSample(
        image_id=image_id,
        image=image,
        output=orjson.dumps(response).decode("utf-8"),
        meta=sample_meta,
    )
labelr/main.py CHANGED
@@ -5,6 +5,7 @@ from openfoodfacts.utils import get_logger
5
5
 
6
6
  from labelr.apps import datasets as dataset_app
7
7
  from labelr.apps import evaluate as evaluate_app
8
+ from labelr.apps import google_batch as google_batch_app
8
9
  from labelr.apps import hugging_face as hf_app
9
10
  from labelr.apps import label_studio as ls_app
10
11
  from labelr.apps import train as train_app
@@ -84,6 +85,11 @@ app.add_typer(
84
85
  name="evaluate",
85
86
  help="Visualize and evaluate trained models.",
86
87
  )
88
+ app.add_typer(
89
+ google_batch_app.app,
90
+ name="google-batch",
91
+ help="Generate datasets and launch batch jobs on Google Gemini.",
92
+ )
87
93
 
88
94
  if __name__ == "__main__":
89
95
  app()
File without changes
@@ -0,0 +1,17 @@
1
+ import datasets
2
+
3
# Schema of the Hugging Face datasets holding image classification samples.
HF_DS_CLASSIFICATION_FEATURES = datasets.Features(
    {
        # Identifier of the image
        "image_id": datasets.Value("string"),
        # The image itself, with its dimensions
        "image": datasets.features.Image(),
        "width": datasets.Value("int64"),
        "height": datasets.Value("int64"),
        # Information about the origin of the image
        "meta": {
            "barcode": datasets.Value("string"),
            "off_image_id": datasets.Value("string"),
            "image_url": datasets.Value("string"),
        },
        # Label of the image, as an index and as a name
        "category_id": datasets.Value("int64"),
        "category_name": datasets.Value("string"),
    }
)
@@ -0,0 +1,14 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+
4
class SampleMeta(BaseModel):
    """Metadata describing where the image of a sample comes from.

    The three fields have no default value: each must be passed explicitly,
    `None` being accepted when the information does not apply.
    """

    barcode: str | None = Field(
        ...,
        description="The barcode of the product, if applicable",
    )
    off_image_id: str | None = Field(
        ...,
        description="The Open Food Facts image ID associated with the image, if applicable",
    )
    image_url: str | None = Field(
        ...,
        description="The URL of the image, if applicable",
    )