active_vision-0.0.1-py3-none-any.whl → active_vision-0.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- active_vision/__init__.py +3 -1
- active_vision/core.py +291 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/METADATA +55 -36
- active_vision-0.0.3.dist-info/RECORD +7 -0
- active_vision-0.0.1.dist-info/RECORD +0 -6
- {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/LICENSE +0 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/WHEEL +0 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/top_level.txt +0 -0
active_vision/__init__.py
CHANGED
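The body of this change is not shown by the viewer; the summary counts it as +3/-1, and the new RECORD below lists the file at 43 bytes (up from 22). A hypothetical sketch of what the 0.0.3 `__init__.py` could look like, given that the README imports `ActiveLearner` from the package root (the actual contents are not confirmed by this diff):

```python
# Hypothetical reconstruction; the diff body for this file isn't shown.
# Re-export the new ActiveLearner and bump the version string.
from .core import ActiveLearner

__version__ = "0.0.3"
```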
active_vision/core.py
ADDED
@@ -0,0 +1,291 @@
```python
import pandas as pd
from loguru import logger
from fastai.vision.models import resnet18, resnet34
from fastai.callback.all import ShowGraphCallback
from fastai.vision.all import (
    ImageDataLoaders,
    aug_transforms,
    Resize,
    vision_learner,
    accuracy,
    valley,
    slide,
    minimum,
    steep,
)
import torch
import torch.nn.functional as F

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)


class ActiveLearner:
    def __init__(self, model_name: str):
        self.model = self.load_model(model_name)

    def load_model(self, model_name: str):
        models = {"resnet18": resnet18, "resnet34": resnet34}
        logger.info(f"Loading model {model_name}")
        if model_name not in models:
            logger.error(f"Model {model_name} not found")
            raise ValueError(f"Model {model_name} not found")
        return models[model_name]

    def load_dataset(
        self,
        df: pd.DataFrame,
        filepath_col: str,
        label_col: str,
        valid_pct: float = 0.2,
        batch_size: int = 16,
        image_size: int = 224,
    ):
        logger.info(f"Loading dataset from {filepath_col} and {label_col}")
        self.train_set = df.copy()

        logger.info("Creating dataloaders")
        self.dls = ImageDataLoaders.from_df(
            df,
            path=".",
            valid_pct=valid_pct,
            fn_col=filepath_col,
            label_col=label_col,
            bs=batch_size,
            item_tfms=Resize(image_size),
            batch_tfms=aug_transforms(size=image_size, min_scale=0.75),
        )
        logger.info("Creating learner")
        self.learn = vision_learner(self.dls, self.model, metrics=accuracy).to_fp16()
        self.class_names = self.dls.vocab
        logger.info("Done. Ready to train.")

    def lr_find(self):
        logger.info("Finding optimal learning rate")
        self.lrs = self.learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
        logger.info(f"Optimal learning rate: {self.lrs.valley}")

    def train(self, epochs: int, lr: float):
        logger.info(f"Training for {epochs} epochs with learning rate: {lr}")
        self.learn.fine_tune(epochs, lr, cbs=[ShowGraphCallback()])

    def predict(self, filepaths: list[str], batch_size: int = 16):
        """
        Run inference on an unlabeled dataset. Returns a df with filepaths and predicted labels, and confidence scores.
        """
        logger.info(f"Running inference on {len(filepaths)} samples")
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        preds, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        self.pred_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "pred_label": [self.learn.dls.vocab[i] for i in cls_preds.numpy()],
                "pred_conf": torch.max(F.softmax(preds, dim=1), dim=1)[0].numpy(),
            }
        )
        return self.pred_df

    def evaluate(
        self, df: pd.DataFrame, filepath_col: str, label_col: str, batch_size: int = 16
    ):
        """
        Evaluate on a labeled dataset. Returns a score.
        """
        self.eval_set = df.copy()

        filepaths = self.eval_set[filepath_col].tolist()
        labels = self.eval_set[label_col].tolist()
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        preds, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        self.eval_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "label": labels,
                "pred_label": [self.learn.dls.vocab[i] for i in cls_preds.numpy()],
            }
        )

        accuracy = float((self.eval_df["label"] == self.eval_df["pred_label"]).mean())
        logger.info(f"Accuracy: {accuracy:.2%}")
        return accuracy

    def sample_uncertain(self, df: pd.DataFrame, num_samples: int):
        """
        Sample top `num_samples` low confidence samples. Returns a df with filepaths and predicted labels, and confidence scores.
        """
        logger.info(f"Getting top {num_samples} low confidence samples")
        uncertain_df = df.sort_values(by="pred_conf", ascending=True).head(num_samples)
        return uncertain_df

    def label(self, df: pd.DataFrame, output_filename: str = "labeled"):
        """
        Launch a labeling interface for the user to label the samples.
        Input is a df with filepaths listing the files to be labeled. Output is a df with filepaths and labels.
        """
        import gradio as gr

        shortcut_js = """
        <script>
        function shortcuts(e) {
            // Only block shortcuts if we're in a text input or textarea
            if (e.target.tagName.toLowerCase() === "textarea" ||
                (e.target.tagName.toLowerCase() === "input" && e.target.type.toLowerCase() === "text")) {
                return;
            }

            if (e.key.toLowerCase() == "w") {
                document.getElementById("submit_btn").click();
            } else if (e.key.toLowerCase() == "d") {
                document.getElementById("next_btn").click();
            } else if (e.key.toLowerCase() == "a") {
                document.getElementById("back_btn").click();
            }
        }
        document.addEventListener('keypress', shortcuts, false);
        </script>
        """

        logger.info(f"Launching labeling interface for {len(df)} samples")

        filepaths = df["filepath"].tolist()

        with gr.Blocks(head=shortcut_js) as demo:
            current_index = gr.State(value=0)

            filename = gr.Textbox(
                label="Filename", value=filepaths[0], interactive=False
            )

            image = gr.Image(
                type="filepath", label="Image", value=filepaths[0], height=500
            )
            category = gr.Radio(choices=self.class_names, label="Select Category")

            with gr.Row():
                back_btn = gr.Button("← Previous (A)", elem_id="back_btn")
                submit_btn = gr.Button(
                    "Submit (W)",
                    variant="primary",
                    elem_id="submit_btn",
                    interactive=False,
                )
                next_btn = gr.Button("Next → (D)", elem_id="next_btn")

            progress = gr.Slider(
                minimum=0,
                maximum=len(filepaths) - 1,
                value=0,
                label="Progress",
                interactive=False,
            )

            finish_btn = gr.Button("Finish Labeling", variant="primary")

            def update_submit_btn(choice):
                return gr.Button(interactive=choice is not None)

            category.change(
                fn=update_submit_btn, inputs=[category], outputs=[submit_btn]
            )

            def navigate(current_idx, direction):
                next_idx = current_idx + direction
                if 0 <= next_idx < len(filepaths):
                    return filepaths[next_idx], filepaths[next_idx], next_idx, next_idx
                return (
                    filepaths[current_idx],
                    filepaths[current_idx],
                    current_idx,
                    current_idx,
                )

            def save_and_next(current_idx, selected_category):
                if selected_category is None:
                    return (
                        filepaths[current_idx],
                        filepaths[current_idx],
                        current_idx,
                        current_idx,
                    )

                # Save the current annotation
                with open(f"{output_filename}.csv", "a") as f:
                    f.write(f"{filepaths[current_idx]},{selected_category}\n")

                # Move to next image if not at the end
                next_idx = current_idx + 1
                if next_idx >= len(filepaths):
                    return (
                        filepaths[current_idx],
                        filepaths[current_idx],
                        current_idx,
                        current_idx,
                    )
                return filepaths[next_idx], filepaths[next_idx], next_idx, next_idx

            def convert_csv_to_parquet():
                try:
                    df = pd.read_csv(f"{output_filename}.csv", header=None)
                    df.columns = ["filepath", "label"]
                    df = df.drop_duplicates(subset=["filepath"], keep="last")
                    df.to_parquet(f"{output_filename}.parquet")
                    gr.Info(f"Annotation saved to {output_filename}.parquet")
                except Exception as e:
                    logger.error(e)
                    return

            back_btn.click(
                fn=lambda idx: navigate(idx, -1),
                inputs=[current_index],
                outputs=[filename, image, current_index, progress],
            )

            next_btn.click(
                fn=lambda idx: navigate(idx, 1),
                inputs=[current_index],
                outputs=[filename, image, current_index, progress],
            )

            submit_btn.click(
                fn=save_and_next,
                inputs=[current_index, category],
                outputs=[filename, image, current_index, progress],
            )

            finish_btn.click(fn=convert_csv_to_parquet)

        demo.launch(height=1000)

    def add_to_train_set(self, df: pd.DataFrame, output_filename: str):
        """
        Add samples to the training set.
        """
        new_train_set = df.copy()
        # new_train_set.drop(columns=["pred_conf"], inplace=True)
        # new_train_set.rename(columns={"pred_label": "label"}, inplace=True)

        # len_old = len(self.train_set)

        logger.info(f"Adding {len(new_train_set)} samples to training set")
        self.train_set = pd.concat([self.train_set, new_train_set])

        self.train_set = self.train_set.drop_duplicates(
            subset=["filepath"], keep="last"
        )
        self.train_set.reset_index(drop=True, inplace=True)

        self.train_set.to_parquet(f"{output_filename}.parquet")
        logger.info(f"Saved training set to {output_filename}.parquet")

        # if len(self.train_set) == len_old:
        #     logger.warning("No new samples added to training set")

        # elif len_old + len(new_train_set) < len(self.train_set):
        #     logger.warning("Some samples were duplicates and removed from training set")

        # else:
        #     logger.info("All new samples added to training set")
        #     logger.info(f"Training set now has {len(self.train_set)} samples")
```
{active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/METADATA
CHANGED
````diff
@@ -1,12 +1,13 @@
 Metadata-Version: 2.2
 Name: active-vision
-Version: 0.0.1
+Version: 0.0.3
 Summary: Active learning for edge vision.
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: datasets>=3.2.0
 Requires-Dist: fastai>=2.7.18
+Requires-Dist: gradio>=5.12.0
 Requires-Dist: ipykernel>=6.29.5
 Requires-Dist: ipywidgets>=8.1.5
 Requires-Dist: loguru>=0.7.3
@@ -14,75 +15,93 @@ Requires-Dist: seaborn>=0.13.2
 
 ![]()
 ![]()
+[](https://pypi.org/project/active-vision/)
+![]()
 
 <p align="center">
-    <img src="
+    <img src="https://raw.githubusercontent.com/dnth/active-vision/main/assets/logo.png" alt="active-vision">
 </p>
 
 Active learning at the edge for computer vision.
 
-The goal of this project is to create a framework for active learning
+The goal of this project is to create a framework for the active learning loop for computer vision deployed on edge devices.
 
-## 
+## Installation
+I recommend using [uv](https://docs.astral.sh/uv/) to set up a virtual environment and install the package. You can also use any other virtual environment of your choice.
 
-
-- User interface: streamlit
-- Database: sqlite
-- Experiment tracking: wandb
+If you're using uv:
 
-
+```bash
+uv venv
+uv sync
+```
+Once the virtual environment is created, you can install the package using pip.
 
-PyPI
+Get a release from PyPI
 ```bash
 pip install active-vision
 ```
 
-
+Install from source
 ```bash
 git clone https://github.com/dnth/active-vision.git
 cd active-vision
 pip install -e .
 ```
 
-
+> [!TIP]
+> If you're using uv, add `uv` before the pip install command to install into your virtual environment. E.g.:
+> ```bash
+> uv pip install active-vision
+> ```
 
-
-
+## Usage
+See the [notebook](./nbs/04_relabel_loop.ipynb) for a complete example.
 
-
-
+Be sure to prepare 3 datasets:
+- train: A dataframe of an existing labeled training dataset.
+- unlabeled: A dataframe of unlabeled data which we will sample from using active learning.
+- eval: A dataframe of labeled data which we will use to evaluate the performance of the model. (Optional)
 
-
-
+```python
+from active_vision import ActiveLearner
+import pandas as pd
+
+# Create an active learner instance with a model
+al = ActiveLearner("resnet18")
 
-# 
-
+# Load dataset
+train_df = pd.read_parquet("training_samples.parquet")
+al.load_dataset(train_df, filepath_col="filepath", label_col="label")
 
-# Train
-
+# Train model
+al.train(epochs=3, lr=1e-3)
 
-# 
-
+# Evaluate the model on a *labeled* evaluation set
+accuracy = al.evaluate(eval_df, filepath_col="filepath", label_col="label")
 
-# 
-
+# Get predictions from an *unlabeled* set
+pred_df = al.predict(filepaths)
 
-# 
-
+# Sample low confidence predictions from unlabeled set
+uncertain_df = al.sample_uncertain(pred_df, num_samples=10)
 
-# 
-
+# Launch a Gradio UI to label the low confidence samples
+al.label(uncertain_df, output_filename="uncertain")
+```
 
-
-model.random_sampling()
+![]()
 
-
-
+Once complete, the labeled samples will be saved into a new df.
+We can now add the newly labeled data to the training set.
 
-
-
+```python
+# Add newly labeled data to training set and save as a new file active_labeled
+al.add_to_train_set(labeled_df, output_filename="active_labeled")
 ```
 
+Repeat the process until the model is good enough. Use the dataset to train a larger model and deploy.
+
 ## Workflow
 There are two workflows for active learning at the edge that we can use depending on the availability of labeled data.
 
````
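The "repeat the process" step in the README maps onto the new API like this: a sketch, assuming the labeling UI above was run with `output_filename="uncertain"` and that `eval_df` is the labeled evaluation dataframe from the README (file names are illustrative):

```python
import pandas as pd

# Fold the newly labeled samples back into the training set; duplicates
# are dropped by filepath, keeping the most recent label.
labeled_df = pd.read_parquet("uncertain.parquet")
al.add_to_train_set(labeled_df, output_filename="active_labeled")

# Retrain on the grown set and re-check accuracy on the held-out eval set.
train_df = pd.read_parquet("active_labeled.parquet")
al.load_dataset(train_df, filepath_col="filepath", label_col="label")
al.train(epochs=3, lr=1e-3)
accuracy = al.evaluate(eval_df, filepath_col="filepath", label_col="label")
```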
active_vision-0.0.3.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,7 @@
+active_vision/__init__.py,sha256=hZp8jB284ByY44Q5cdwTt9Zz5n4QWXnz0OexpEE9muk,43
+active_vision/core.py,sha256=0aXDI5Gpj0Spk7TSIxJf8aJDDBgZh0-jkWdYyZ1Zric,10713
+active_vision-0.0.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+active_vision-0.0.3.dist-info/METADATA,sha256=g629Kn07n4ZXOOX5cW1nPQK1IR9Mm5vW_z7RqxdwKgY,9385
+active_vision-0.0.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+active_vision-0.0.3.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
+active_vision-0.0.3.dist-info/RECORD,,
```
active_vision-0.0.1.dist-info/RECORD
REMOVED
```diff
@@ -1,6 +0,0 @@
-active_vision/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
-active_vision-0.0.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-active_vision-0.0.1.dist-info/METADATA,sha256=lPOTTVSPAaX3Rn9Q1ci_jgoQOC-HFpQIyTNqrouOYEs,7936
-active_vision-0.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-active_vision-0.0.1.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
-active_vision-0.0.1.dist-info/RECORD,,
```
{active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/LICENSE
File without changes
{active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/WHEEL
File without changes
{active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/top_level.txt
File without changes