active-vision 0.0.5-py3-none-any.whl → 0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
active_vision/__init__.py CHANGED
@@ -1,3 +1,3 @@
- __version__ = "0.0.5"
+ __version__ = "0.1.0"

  from .core import *
active_vision/core.py CHANGED
@@ -1,17 +1,6 @@
  import pandas as pd
  from loguru import logger
- from fastai.callback.all import ShowGraphCallback
- from fastai.vision.all import (
-     ImageDataLoaders,
-     aug_transforms,
-     Resize,
-     vision_learner,
-     accuracy,
-     valley,
-     slide,
-     minimum,
-     steep,
- )
+ from fastai.vision.all import *
  import torch
  import torch.nn.functional as F

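Note on the import change: `fastai.vision.all` re-exports the names the explicit list used to pull in (`ShowGraphCallback`, `vision_learner`, `accuracy`, and the `lr_find` suggestion functions `minimum`, `steep`, `valley`, `slide`), plus `load_learner`, which the new loading path below depends on.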
@@ -22,7 +11,28 @@ warnings.filterwarnings("ignore", category=FutureWarning)


  class ActiveLearner:
-     def __init__(self, model_name: str):
+     """
+     Active Learning framework for computer vision tasks.
+
+     Attributes:
+         Model Related:
+             model: The base model architecture (str or Callable)
+             learn: fastai Learner object for training
+             lrs: Learning rate finder results
+
+         Data Related:
+             train_set (pd.DataFrame): Training dataset
+             eval_set (pd.DataFrame): Evaluation dataset with ground truth labels
+             dls: fastai DataLoaders object
+             class_names: List of class names from the dataset
+             num_classes (int): Number of classes in the dataset
+
+         Prediction Related:
+             pred_df (pd.DataFrame): Predictions on a dataframe
+             eval_df (pd.DataFrame): Predictions on evaluation data
+     """
+
+     def __init__(self, model_name: str | Callable):
          self.model = self.load_model(model_name)

      def load_model(self, model_name: str | Callable):
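With the widened type hint, the constructor accepts either a model name string or a callable that builds the model. A minimal usage sketch (the architecture name is illustrative):

```python
from active_vision import ActiveLearner

# Either a model name string or a callable that constructs a model works now.
learner = ActiveLearner("resnet18")
```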
@@ -43,6 +53,7 @@ class ActiveLearner:
          batch_size: int = 16,
          image_size: int = 224,
          batch_tfms: Callable = None,
+         learner_path: str = None,
      ):
          logger.info(f"Loading dataset from {filepath_col} and {label_col}")
          self.train_set = df.copy()
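The new `learner_path` parameter lets `load_dataset` restore a previously exported fastai learner (via `load_learner`, see the next hunk) instead of creating a fresh `vision_learner`. A hedged sketch, assuming a dataframe with `filepath`/`label` columns and a learner exported earlier to `model.pkl` (hypothetical path):

```python
learner.load_dataset(
    df,
    filepath_col="filepath",   # column holding image paths
    label_col="label",         # column holding class labels
    learner_path="model.pkl",  # hypothetical: a learner saved with learn.export()
)
```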
@@ -58,22 +69,66 @@ class ActiveLearner:
              item_tfms=Resize(image_size),
              batch_tfms=batch_tfms,
          )
-         logger.info("Creating learner")
-         self.learn = vision_learner(self.dls, self.model, metrics=accuracy).to_fp16()
+
+         if learner_path:
+             logger.info(f"Loading learner from {learner_path}")
+             gpu_available = torch.cuda.is_available()
+             if gpu_available:
+                 logger.info(f"Loading learner on GPU.")
+             else:
+                 logger.info(f"Loading learner on CPU.")
+
+             self.learn = load_learner(learner_path, cpu=not gpu_available)
+         else:
+             logger.info("Creating learner")
+             self.learn = vision_learner(
+                 self.dls, self.model, metrics=accuracy
+             ).to_fp16()
+
          self.class_names = self.dls.vocab
+         self.num_classes = self.dls.c
          logger.info("Done. Ready to train.")

-     def show_batch(self):
-         self.dls.show_batch()
+     def show_batch(
+         self,
+         num_samples: int = 9,
+         unique: bool = False,
+         num_rows: int = None,
+         num_cols: int = None,
+     ):
+         """
+         Show a batch of images from the dataset.
+
+         Args:
+             num_samples: Number of samples to show.
+             unique: Whether to show unique samples.
+             num_rows: Number of rows in the grid.
+             num_cols: Number of columns in the grid.
+         """
+         self.dls.show_batch(
+             max_n=num_samples, unique=unique, nrows=num_rows, ncols=num_cols
+         )

      def lr_find(self):
          logger.info("Finding optimal learning rate")
          self.lrs = self.learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
          logger.info(f"Optimal learning rate: {self.lrs.valley}")

-     def train(self, epochs: int, lr: float):
-         logger.info(f"Training for {epochs} epochs with learning rate: {lr}")
-         self.learn.fine_tune(epochs, lr, cbs=[ShowGraphCallback()])
+     def train(self, epochs: int, lr: float, head_tuning_epochs: int = 1):
+         """
+         Train the model.
+
+         Args:
+             epochs: Number of epochs to train for.
+             lr: Learning rate.
+             head_tuning_epochs: Number of epochs to train the head.
+         """
+         logger.info(f"Training head for {head_tuning_epochs} epochs")
+         logger.info(f"Training model end-to-end for {epochs} epochs")
+         logger.info(f"Learning rate: {lr} with one-cycle learning rate scheduler")
+         self.learn.fine_tune(
+             epochs, lr, freeze_epochs=head_tuning_epochs, cbs=[ShowGraphCallback()]
+         )

      def predict(self, filepaths: list[str], batch_size: int = 16):
          """
@@ -131,11 +186,17 @@ class ActiveLearner:
          """

          # Remove samples that is already in the training set
-         df = df[~df["filepath"].isin(self.train_set["filepath"])]
+         df = df[~df["filepath"].isin(self.train_set["filepath"])].copy()

          if strategy == "least-confidence":
              logger.info(f"Getting top {num_samples} low confidence samples")
-             uncertain_df = df.sort_values(by="pred_conf", ascending=True).head(
+
+             df.loc[:, "uncertainty_score"] = 1 - (df["pred_conf"]) / (
+                 self.num_classes - (self.num_classes - 1)
+             )
+
+             # Sort by descending uncertainty score
+             uncertain_df = df.sort_values(by="uncertainty_score", ascending=False).head(
                  num_samples
              )
              return uncertain_df
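A note on the new scoring expression: the denominator `self.num_classes - (self.num_classes - 1)` always evaluates to 1, so `uncertainty_score` reduces to `1 - pred_conf`, and the descending sort is equivalent to the old ascending sort on `pred_conf`. A normalized least-confidence score, which this expression appears to be reaching for, would scale by the class count; a minimal sketch (names assumed from the surrounding code):

```python
# Hedged sketch: normalized least-confidence uncertainty, which maps the
# score onto [0, 1] for any number of classes (uniform predictions -> 1.0).
n = self.num_classes
df.loc[:, "uncertainty_score"] = (1 - df["pred_conf"]) * (n / (n - 1))
```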
@@ -197,15 +258,15 @@ class ActiveLearner:
                  return;
              }

-             if (e.key.toLowerCase() == "w") {
+             if (e.key === "ArrowUp" || e.key === "Enter") {
                  document.getElementById("submit_btn").click();
-             } else if (e.key.toLowerCase() == "d") {
+             } else if (e.key === "ArrowRight") {
                  document.getElementById("next_btn").click();
-             } else if (e.key.toLowerCase() == "a") {
+             } else if (e.key === "ArrowLeft") {
                  document.getElementById("back_btn").click();
              }
          }
-         document.addEventListener('keypress', shortcuts, false);
+         document.addEventListener('keydown', shortcuts, false);
      </script>
      """

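The listener change is what makes the new bindings work: arrow keys do not emit `keypress` events (the event is also deprecated), whereas `keydown` fires for every key, including ArrowUp/ArrowLeft/ArrowRight.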
@@ -216,24 +277,45 @@ class ActiveLearner:
          with gr.Blocks(head=shortcut_js) as demo:
              current_index = gr.State(value=0)

-             filename = gr.Textbox(
-                 label="Filename", value=filepaths[0], interactive=False
-             )
-
              image = gr.Image(
                  type="filepath", label="Image", value=filepaths[0], height=500
              )
-             category = gr.Radio(choices=self.class_names, label="Select Category")

              with gr.Row():
-                 back_btn = gr.Button("← Previous (A)", elem_id="back_btn")
+                 filename = gr.Textbox(
+                     label="Filename", value=filepaths[0], interactive=False
+                 )
+
+                 pred_label = gr.Textbox(
+                     label="Predicted Label",
+                     value=df["pred_label"].iloc[0]
+                     if "pred_label" in df.columns
+                     else "",
+                     interactive=False,
+                 )
+                 pred_conf = gr.Textbox(
+                     label="Confidence",
+                     value=f"{df['pred_conf'].iloc[0]:.2%}"
+                     if "pred_conf" in df.columns
+                     else "",
+                     interactive=False,
+                 )
+
+             category = gr.Radio(
+                 choices=self.class_names,
+                 label="Select Category",
+                 value=df["pred_label"].iloc[0] if "pred_label" in df.columns else None,
+             )
+
+             with gr.Row():
+                 back_btn = gr.Button("← Previous", elem_id="back_btn")
                  submit_btn = gr.Button(
-                     "Submit (W)",
+                     "Submit (↑/Enter)",
                      variant="primary",
                      elem_id="submit_btn",
                      interactive=False,
                  )
-                 next_btn = gr.Button("Next → (D)", elem_id="next_btn")
+                 next_btn = gr.Button("Next →", elem_id="next_btn")

              progress = gr.Slider(
                  minimum=0,
@@ -245,6 +327,73 @@ class ActiveLearner:

              finish_btn = gr.Button("Finish Labeling", variant="primary")

+             with gr.Accordion("Zero-Shot Inference", open=False) as zero_shot_accordion:
+                 gr.Markdown("""
+                 Uses a VLM to predict the label of the image.
+                 """)
+
+                 import xinfer
+                 from xinfer.model_registry import model_registry
+                 from xinfer.types import ModelInputOutput
+
+                 # Get models and filter for image-to-text models
+                 all_models = model_registry.list_models()
+                 model_list = [
+                     model.id
+                     for model in all_models
+                     if model.input_output == ModelInputOutput.IMAGE_TEXT_TO_TEXT
+                 ]
+
+                 with gr.Row():
+                     with gr.Row():
+                         model_dropdown = gr.Dropdown(
+                             choices=model_list,
+                             label="Select a model",
+                             value="vikhyatk/moondream2",
+                         )
+                         device_dropdown = gr.Dropdown(
+                             choices=["cuda", "cpu"],
+                             label="Device",
+                             value="cuda" if torch.cuda.is_available() else "cpu",
+                         )
+                         dtype_dropdown = gr.Dropdown(
+                             choices=["float32", "float16", "bfloat16"],
+                             label="Data Type",
+                             value="float16" if torch.cuda.is_available() else "float32",
+                         )
+
+                     with gr.Column():
+                         prompt_textbox = gr.Textbox(
+                             label="Prompt",
+                             lines=3,
+                             value=f"Classify the image into one of the following categories: {self.class_names}",
+                             interactive=True,
+                         )
+                         inference_btn = gr.Button("Run Inference", variant="primary")
+
+                 result_textbox = gr.Textbox(
+                     label="Result",
+                     lines=3,
+                     interactive=False,
+                 )
+
+                 def run_zero_shot_inference(prompt, model, device, dtype, current_filename):
+                     model = xinfer.create_model(model, device=device, dtype=dtype)
+                     result = model.infer(current_filename, prompt).text
+                     return result
+
+                 inference_btn.click(
+                     fn=run_zero_shot_inference,
+                     inputs=[
+                         prompt_textbox,
+                         model_dropdown,
+                         device_dropdown,
+                         dtype_dropdown,
+                         filename,
+                     ],
+                     outputs=[result_textbox],
+                 )
+
              def update_submit_btn(choice):
                  return gr.Button(interactive=choice is not None)

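The accordion wires zero-shot labeling through xinfer's model registry. The same three calls used above work outside the UI as well; a minimal sketch (model id, prompt, and image path are illustrative):

```python
import xinfer

# Same pattern as the Gradio callback above: create a VLM, run inference,
# and read the generated text from the result object.
model = xinfer.create_model("vikhyatk/moondream2", device="cpu", dtype="float32")
prompt = "Classify the image into one of the following categories: ['cat', 'dog']"
print(model.infer("image.jpg", prompt).text)
```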
@@ -253,21 +402,59 @@ class ActiveLearner:
              )

              def navigate(current_idx, direction):
+                 # Convert current_idx to int before arithmetic
+                 current_idx = int(current_idx)
                  next_idx = current_idx + direction
+
                  if 0 <= next_idx < len(filepaths):
-                     return filepaths[next_idx], filepaths[next_idx], next_idx, next_idx
+                     return (
+                         filepaths[next_idx],
+                         filepaths[next_idx],
+                         df["pred_label"].iloc[next_idx]
+                         if "pred_label" in df.columns
+                         else "",
+                         f"{df['pred_conf'].iloc[next_idx]:.2%}"
+                         if "pred_conf" in df.columns
+                         else "",
+                         df["pred_label"].iloc[next_idx]
+                         if "pred_label" in df.columns
+                         else None,
+                         next_idx,
+                         next_idx,
+                     )
                  return (
                      filepaths[current_idx],
                      filepaths[current_idx],
+                     df["pred_label"].iloc[current_idx]
+                     if "pred_label" in df.columns
+                     else "",
+                     f"{df['pred_conf'].iloc[current_idx]:.2%}"
+                     if "pred_conf" in df.columns
+                     else "",
+                     df["pred_label"].iloc[current_idx]
+                     if "pred_label" in df.columns
+                     else None,
                      current_idx,
                      current_idx,
                  )

              def save_and_next(current_idx, selected_category):
+                 # Convert current_idx to int before arithmetic
+                 current_idx = int(current_idx)
+
                  if selected_category is None:
                      return (
                          filepaths[current_idx],
                          filepaths[current_idx],
+                         df["pred_label"].iloc[current_idx]
+                         if "pred_label" in df.columns
+                         else "",
+                         f"{df['pred_conf'].iloc[current_idx]:.2%}"
+                         if "pred_conf" in df.columns
+                         else "",
+                         df["pred_label"].iloc[current_idx]
+                         if "pred_label" in df.columns
+                         else None,
                          current_idx,
                          current_idx,
                      )
@@ -282,10 +469,33 @@ class ActiveLearner:
                  return (
                      filepaths[current_idx],
                      filepaths[current_idx],
+                     df["pred_label"].iloc[current_idx]
+                     if "pred_label" in df.columns
+                     else "",
+                     f"{df['pred_conf'].iloc[current_idx]:.2%}"
+                     if "pred_conf" in df.columns
+                     else "",
+                     df["pred_label"].iloc[current_idx]
+                     if "pred_label" in df.columns
+                     else None,
                      current_idx,
                      current_idx,
                  )
-                 return filepaths[next_idx], filepaths[next_idx], next_idx, next_idx
+                 return (
+                     filepaths[next_idx],
+                     filepaths[next_idx],
+                     df["pred_label"].iloc[next_idx]
+                     if "pred_label" in df.columns
+                     else "",
+                     f"{df['pred_conf'].iloc[next_idx]:.2%}"
+                     if "pred_conf" in df.columns
+                     else "",
+                     df["pred_label"].iloc[next_idx]
+                     if "pred_label" in df.columns
+                     else None,
+                     next_idx,
+                     next_idx,
+                 )

              def convert_csv_to_parquet():
                  try:
@@ -301,19 +511,43 @@ class ActiveLearner:
              back_btn.click(
                  fn=lambda idx: navigate(idx, -1),
                  inputs=[current_index],
-                 outputs=[filename, image, current_index, progress],
+                 outputs=[
+                     filename,
+                     image,
+                     pred_label,
+                     pred_conf,
+                     category,
+                     current_index,
+                     progress,
+                 ],
              )

              next_btn.click(
                  fn=lambda idx: navigate(idx, 1),
                  inputs=[current_index],
-                 outputs=[filename, image, current_index, progress],
+                 outputs=[
+                     filename,
+                     image,
+                     pred_label,
+                     pred_conf,
+                     category,
+                     current_index,
+                     progress,
+                 ],
              )

              submit_btn.click(
                  fn=save_and_next,
                  inputs=[current_index, category],
-                 outputs=[filename, image, current_index, progress],
+                 outputs=[
+                     filename,
+                     image,
+                     pred_label,
+                     pred_conf,
+                     category,
+                     current_index,
+                     progress,
+                 ],
              )

              finish_btn.click(fn=convert_csv_to_parquet)
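Each handler now returns a seven-element tuple, and Gradio maps return values to the `outputs` list positionally, so the order `filename, image, pred_label, pred_conf, category, current_index, progress` must match the tuples built in `navigate` and `save_and_next` exactly.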
@@ -325,10 +559,6 @@ class ActiveLearner:
          Add samples to the training set.
          """
          new_train_set = df.copy()
-         # new_train_set.drop(columns=["pred_conf"], inplace=True)
-         # new_train_set.rename(columns={"pred_label": "label"}, inplace=True)
-
-         # len_old = len(self.train_set)

          logger.info(f"Adding {len(new_train_set)} samples to training set")
          self.train_set = pd.concat([self.train_set, new_train_set])
@@ -340,13 +570,3 @@ class ActiveLearner:

          self.train_set.to_parquet(f"{output_filename}.parquet")
          logger.info(f"Saved training set to {output_filename}.parquet")
-
-         # if len(self.train_set) == len_old:
-         #     logger.warning("No new samples added to training set")
-
-         # elif len_old + len(new_train_set) < len(self.train_set):
-         #     logger.warning("Some samples were duplicates and removed from training set")
-
-         # else:
-         #     logger.info("All new samples added to training set")
-         #     logger.info(f"Training set now has {len(self.train_set)} samples")
{active_vision-0.0.5.dist-info → active_vision-0.1.0.dist-info}/METADATA RENAMED
@@ -1,10 +1,11 @@
  Metadata-Version: 2.2
  Name: active-vision
- Version: 0.0.5
+ Version: 0.1.0
  Summary: Active learning for edge vision.
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ Requires-Dist: accelerate>=1.2.1
  Requires-Dist: datasets>=3.2.0
  Requires-Dist: fastai>=2.7.18
  Requires-Dist: gradio>=5.12.0
@@ -13,6 +14,8 @@ Requires-Dist: ipywidgets>=8.1.5
  Requires-Dist: loguru>=0.7.3
  Requires-Dist: seaborn>=0.13.2
  Requires-Dist: timm>=1.0.13
+ Requires-Dist: transformers>=4.48.0
+ Requires-Dist: xinfer>=0.3.2

  ![Python Version](https://img.shields.io/badge/python-3.10%2B-blue?style=for-the-badge)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg?style=for-the-badge)
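The new dependencies track the zero-shot feature: `xinfer` supplies the VLM wrapper used in the labeling UI, and `transformers` plus `accelerate` are what models such as the default `vikhyatk/moondream2` run on.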
active_vision-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+ active_vision/__init__.py,sha256=dDQijes3C7zAUc_08TyblLSP6Lk0PcPPI8PYgEliKCI,43
+ active_vision/core.py,sha256=D_ve-nMv2EWSaQCOBTggleo-1op8JHXchk0QLicGDqg,21715
+ active_vision-0.1.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ active_vision-0.1.0.dist-info/METADATA,sha256=aA793OK3PGKnKVchMQthXl1H14xcBh_kq9tAO9o6jf0,15944
+ active_vision-0.1.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ active_vision-0.1.0.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
+ active_vision-0.1.0.dist-info/RECORD,,
active_vision-0.0.5.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
- active_vision/__init__.py,sha256=u-7eEtxmLFoQfY0fM9JSs_lWb4e1c7WxR3cC619BTXE,43
- active_vision/core.py,sha256=mKS-ZZunjPgXuavm_J4oYiO9lm6UNRjFEzIn4kNfdVA,13421
- active_vision-0.0.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- active_vision-0.0.5.dist-info/METADATA,sha256=mSFB-DeJ43roTwswTp3oHcG3CIyKnO-7ZCqaYbw26eQ,15846
- active_vision-0.0.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- active_vision-0.0.5.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
- active_vision-0.0.5.dist-info/RECORD,,