active-vision 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
active_vision/__init__.py CHANGED
@@ -1 +1,3 @@
1
- __version__ = "0.0.1"
1
+ __version__ = "0.0.2"
2
+
3
+ from .core import *
active_vision/core.py ADDED
@@ -0,0 +1,149 @@
1
+ import pandas as pd
2
+ from loguru import logger
3
+ from fastai.vision.models import resnet18, resnet34
4
+ from fastai.callback.all import ShowGraphCallback
5
+ from fastai.vision.all import (
6
+ ImageDataLoaders,
7
+ aug_transforms,
8
+ Resize,
9
+ vision_learner,
10
+ accuracy,
11
+ valley,
12
+ slide,
13
+ minimum,
14
+ steep,
15
+ )
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ import warnings
20
+
21
+ warnings.filterwarnings("ignore", category=FutureWarning)
22
+
23
+
24
class ActiveLearner:
    """Active-learning wrapper around a fastai vision learner.

    Typical loop: ``load_dataset`` -> ``train`` -> ``predict`` on unlabeled
    images -> ``sample_uncertain`` -> review labels -> ``add_to_train_set``.
    """

    def __init__(self, model_name: str):
        """Select the architecture named ``model_name``.

        Raises:
            ValueError: if ``model_name`` is not a supported architecture.
        """
        self.model = self.load_model(model_name)

    def load_model(self, model_name: str):
        """Return the architecture callable registered under ``model_name``.

        Supported names are ``"resnet18"`` and ``"resnet34"``. The callable is
        returned as-is (not instantiated); ``load_dataset`` hands it to
        ``vision_learner``.

        Raises:
            ValueError: if ``model_name`` is not supported.
        """
        models = {"resnet18": resnet18, "resnet34": resnet34}
        logger.info(f"Loading model {model_name}")
        if model_name not in models:
            logger.error(f"Model {model_name} not found")
            raise ValueError(f"Model {model_name} not found")
        return models[model_name]

    def load_dataset(
        self,
        df: pd.DataFrame,
        filepath_col: str,
        label_col: str,
        valid_pct: float = 0.2,
        batch_size: int = 16,
        image_size: int = 224,
    ):
        """Build dataloaders and a learner from a labeled DataFrame.

        Args:
            df: Labeled samples, one row per image.
            filepath_col: Column holding image paths (resolved relative to ".").
            label_col: Column holding the class labels.
            valid_pct: Fraction of ``df`` held out for validation.
            batch_size: Batch size used by the dataloaders.
            image_size: Size images are resized/augmented to.

        Side effects: sets ``self.train_set`` (a copy of ``df``),
        ``self.filepath_col``, ``self.label_col``, ``self.dls``,
        ``self.learn`` and ``self.class_names``.
        """
        logger.info(f"Loading dataset from {filepath_col} and {label_col}")
        self.train_set = df.copy()
        # Remember the schema so rows appended later by add_to_train_set
        # line up with the columns of the training set.
        self.filepath_col = filepath_col
        self.label_col = label_col

        logger.info("Creating dataloaders")
        self.dls = ImageDataLoaders.from_df(
            df,
            path=".",
            valid_pct=valid_pct,
            fn_col=filepath_col,
            label_col=label_col,
            bs=batch_size,
            item_tfms=Resize(image_size),
            batch_tfms=aug_transforms(size=image_size, min_scale=0.75),
        )
        logger.info("Creating learner")
        self.learn = vision_learner(self.dls, self.model, metrics=accuracy).to_fp16()
        self.class_names = self.dls.vocab
        logger.info("Done. Ready to train.")

    def lr_find(self):
        """Run fastai's learning-rate finder and store its suggestions.

        The suggestions (minimum, steep, valley, slide) are stored in
        ``self.lrs`` and also returned; the ``valley`` value is logged.
        """
        logger.info("Finding optimal learning rate")
        self.lrs = self.learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
        logger.info(f"Optimal learning rate: {self.lrs.valley}")
        return self.lrs

    def train(self, epochs: int, lr: float):
        """Fine-tune the learner for ``epochs`` epochs at base learning rate ``lr``.

        Uses fastai's ``fine_tune`` and plots the loss curves via
        ``ShowGraphCallback``. Requires ``load_dataset`` to have been called.
        """
        logger.info(f"Training for {epochs} epochs with learning rate: {lr}")
        self.learn.fine_tune(epochs, lr, cbs=[ShowGraphCallback()])

    def predict(self, filepaths: list[str], batch_size: int = 16):
        """
        Run inference on an unlabeled dataset. Returns a df with filepaths and predicted labels, and confidence scores.

        Columns of the returned (and stored ``self.pred_df``) DataFrame:
        ``"filepath"``, ``"pred_label"`` (decoded class name) and
        ``"pred_conf"`` (probability of the predicted class).
        """
        logger.info(f"Running inference on {len(filepaths)} samples")
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        # get_preds already applies the loss function's activation (softmax
        # for single-label classification), so ``probs`` are probabilities.
        # Applying softmax a second time would squash every confidence
        # toward 1 / num_classes.
        probs, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        self.pred_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "pred_label": [self.learn.dls.vocab[i] for i in cls_preds.numpy()],
                "pred_conf": torch.max(probs, dim=1).values.numpy(),
            }
        )
        return self.pred_df

    def evaluate(self, df: pd.DataFrame, filepath_col: str, label_col: str, batch_size: int = 16):
        """
        Evaluate on a labeled dataset. Returns the accuracy as a float in [0, 1].

        Per-sample results (``"filepath"``, ``"label"``, ``"pred_label"``) are
        stored in ``self.eval_df``; a copy of ``df`` is stored in
        ``self.eval_set``.
        """
        self.eval_set = df.copy()

        filepaths = self.eval_set[filepath_col].tolist()
        labels = self.eval_set[label_col].tolist()
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        _, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        self.eval_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "label": labels,
                "pred_label": [self.learn.dls.vocab[i] for i in cls_preds.numpy()],
            }
        )

        # Named ``acc`` so it does not shadow the imported fastai ``accuracy`` metric.
        acc = float((self.eval_df["label"] == self.eval_df["pred_label"]).mean())
        logger.info(f"Accuracy: {acc:.2%}")
        return acc

    def sample_uncertain(self, df: pd.DataFrame, num_samples: int):
        """
        Sample top `num_samples` low confidence samples. Returns a df with filepaths and predicted labels, and confidence scores.

        ``df`` must have a ``"pred_conf"`` column (e.g. the output of
        ``predict``); rows are returned in ascending order of confidence.
        """
        uncertain_df = df.sort_values(
            by="pred_conf", ascending=True
        ).head(num_samples)
        return uncertain_df

    def add_to_train_set(self, df: pd.DataFrame):
        """
        Add samples to the training set.

        ``df`` is expected to look like the output of ``predict`` /
        ``sample_uncertain``: a ``"filepath"`` column plus a ``"pred_label"``
        column whose values become the training labels (review them first).
        If ``df`` already carries the training label column, that column is
        used and ``"pred_label"`` is discarded. ``"pred_conf"`` is dropped
        when present.

        Columns are renamed to the names passed to ``load_dataset`` (falling
        back to ``"filepath"``/``"label"``) so new rows align with
        ``self.train_set``. A file path already in the training set is
        replaced by the new row, so the newest label wins.

        Only ``self.train_set`` is updated; the dataloaders and learner built
        by ``load_dataset`` are not rebuilt here.
        """
        filepath_col = getattr(self, "filepath_col", "filepath")
        label_col = getattr(self, "label_col", "label")

        new_train_set = df.drop(columns=["pred_conf"], errors="ignore")
        if label_col != "pred_label" and label_col in new_train_set.columns:
            # Explicit labels supplied: prefer them over the model's guesses and
            # avoid creating a duplicate label column on rename.
            new_train_set = new_train_set.drop(columns=["pred_label"], errors="ignore")
        new_train_set = new_train_set.rename(columns={"pred_label": label_col})
        if filepath_col not in new_train_set.columns:
            new_train_set = new_train_set.rename(columns={"filepath": filepath_col})

        len_old = len(self.train_set)

        logger.info(f"Adding {len(new_train_set)} samples to training set")
        combined = pd.concat([self.train_set, new_train_set])
        # keep="last": a re-labeled file replaces its earlier entry.
        self.train_set = combined.drop_duplicates(
            subset=[filepath_col], keep="last"
        ).reset_index(drop=True)

        # Rows lost to de-duplication (the concat can only shrink, never grow).
        num_dropped = len_old + len(new_train_set) - len(self.train_set)

        if len(self.train_set) == len_old:
            logger.warning("No new samples added to training set")
        elif num_dropped > 0:
            logger.warning("Some samples were duplicates and removed from training set")
        else:
            logger.info("All new samples added to training set")
        logger.info(f"Training set now has {len(self.train_set)} samples")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: active-vision
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Active learning for edge vision.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -14,9 +14,11 @@ Requires-Dist: seaborn>=0.13.2
14
14
 
15
15
  ![Python Version](https://img.shields.io/badge/python-3.10%2B-blue?style=for-the-badge)
16
16
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg?style=for-the-badge)
17
+ ![PyPI](https://img.shields.io/pypi/v/active-vision?style=for-the-badge)
18
+ ![Downloads](https://img.shields.io/pepy/dt/active-vision?style=for-the-badge&logo=pypi&logoColor=white&label=Downloads&color=purple)
17
19
 
18
20
  <p align="center">
19
- <img src="./assets/logo.png" alt="active-vision">
21
+ <img src="https://raw.githubusercontent.com/dnth/active-vision/main/assets/logo.png" alt="active-vision">
20
22
  </p>
21
23
 
22
24
  Active learning at the edge for computer vision.
@@ -44,43 +46,37 @@ cd active-vision
44
46
  pip install -e .
45
47
  ```
46
48
 
47
- ## Usage [WIP]
49
+ ## Usage
50
+ See the [notebook](https://github.com/dnth/active-vision/blob/main/nbs/end-to-end.ipynb) for a complete example.
48
51
 
49
52
  ```python
50
- import active_vision as av
53
+ from active_vision import ActiveLearner
54
+ import pandas as pd
51
55
 
52
- # Load a model
53
- model = av.load_model("resnet18")
56
+ # Create an active learner instance with a model
57
+ al = ActiveLearner("resnet18")
54
58
 
55
- # Load a dataset
56
- dataset = av.load_dataset(df)
57
-
58
- # Inital sampling
59
- dataset = av.initial_sampling(dataset, n_samples=10)
59
+ # Load the dataset into the active learner
60
+ train_df = pd.read_parquet("training_samples.parquet")
61
+ al.load_dataset(train_df, "filepath", "label")
60
62
 
61
63
  # Train the model
62
- model.train()
63
-
64
- # Save the model
65
- model.save()
66
-
67
- # Evaluate the model
68
- model.evaluate(df)
64
+ al.train(epochs=3, lr=1e-3)
69
65
 
70
- # Uncertainty sampling to get the lowest confidence images
71
- model.uncertainty_sampling()
66
+ # Load evaluation data
67
+ eval_df = pd.read_parquet("evaluation_samples.parquet")
72
68
 
73
- # Diversity sampling to get the most diverse images (outliers)
74
- model.diversity_sampling()
69
+ # Evaluate the model on a labeled evaluation set
70
+ accuracy = al.evaluate(eval_df, "filepath", "label")
75
71
 
76
- # Random sampling
77
- model.random_sampling()
72
+ # Get predictions for unlabeled images (`filepaths` is a list of image paths)
73
+ pred_df = al.predict(filepaths)
78
74
 
79
- # Merge the datasets
80
- dataset = av.merge_datasets(dataset, dataset_2)
75
+ # Sample low confidence predictions
76
+ uncertain_df = al.sample_uncertain(pred_df, num_samples=10)
81
77
 
82
- # Launch a streamlit app to label the images
83
- av.label_images(dataset)
78
+ # Add the samples to the training set; their `pred_label` values are used as labels, so review/correct them first
79
+ al.add_to_train_set(uncertain_df)
84
80
  ```
85
81
 
86
82
  ## Workflow
@@ -0,0 +1,7 @@
1
+ active_vision/__init__.py,sha256=5VE_DRQ_Rgbo7NlPh3-rP2pUClK48jGxPqAcptBscZ8,43
2
+ active_vision/core.py,sha256=RBVabC350wucYl7KJgIp3fc1pS9pxtG14iDb-ZyBJxI,5262
3
+ active_vision-0.0.2.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
4
+ active_vision-0.0.2.dist-info/METADATA,sha256=7_eqZJnGeIPjb4LLZ-Bqu1AMJ_h77_0bNRyS_COEv5w,8350
5
+ active_vision-0.0.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
6
+ active_vision-0.0.2.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
7
+ active_vision-0.0.2.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- active_vision/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
2
- active_vision-0.0.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
3
- active_vision-0.0.1.dist-info/METADATA,sha256=lPOTTVSPAaX3Rn9Q1ci_jgoQOC-HFpQIyTNqrouOYEs,7936
4
- active_vision-0.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
5
- active_vision-0.0.1.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
6
- active_vision-0.0.1.dist-info/RECORD,,