active-vision 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- active_vision/__init__.py +3 -1
- active_vision/core.py +149 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/METADATA +24 -28
- active_vision-0.0.2.dist-info/RECORD +7 -0
- active_vision-0.0.1.dist-info/RECORD +0 -6
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/LICENSE +0 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/WHEEL +0 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/top_level.txt +0 -0
active_vision/__init__.py
CHANGED
active_vision/core.py
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
import pandas as pd
|
2
|
+
from loguru import logger
|
3
|
+
from fastai.vision.models import resnet18, resnet34
|
4
|
+
from fastai.callback.all import ShowGraphCallback
|
5
|
+
from fastai.vision.all import (
|
6
|
+
ImageDataLoaders,
|
7
|
+
aug_transforms,
|
8
|
+
Resize,
|
9
|
+
vision_learner,
|
10
|
+
accuracy,
|
11
|
+
valley,
|
12
|
+
slide,
|
13
|
+
minimum,
|
14
|
+
steep,
|
15
|
+
)
|
16
|
+
import torch
|
17
|
+
import torch.nn.functional as F
|
18
|
+
|
19
|
+
import warnings
|
20
|
+
|
21
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
22
|
+
|
23
|
+
|
24
|
+
class ActiveLearner:
    """Thin active-learning wrapper around a fastai vision learner.

    Typical flow: ``load_dataset`` -> ``train`` -> ``predict`` on unlabeled
    files -> ``sample_uncertain`` -> ``add_to_train_set``.
    """

    def __init__(self, model_name: str):
        """Resolve ``model_name`` to an architecture and keep it on ``self.model``."""
        self.model = self.load_model(model_name)

    def load_model(self, model_name: str):
        """Return the architecture registered under ``model_name``.

        Raises:
            ValueError: if ``model_name`` is not one of the supported names.
        """
        models = {"resnet18": resnet18, "resnet34": resnet34}
        logger.info(f"Loading model {model_name}")
        if model_name not in models:
            logger.error(f"Model {model_name} not found")
            raise ValueError(f"Model {model_name} not found")
        return models[model_name]

    def load_dataset(
        self,
        df: pd.DataFrame,
        filepath_col: str,
        label_col: str,
        valid_pct: float = 0.2,
        batch_size: int = 16,
        image_size: int = 224,
    ) -> None:
        """Store ``df`` as the training set and build dataloaders and a learner.

        Sets ``self.train_set`` (a copy of ``df``), ``self.dls``,
        ``self.learn`` and ``self.class_names``. The column names are also
        remembered so that ``add_to_train_set`` can keep the same schema.
        """
        logger.info(f"Loading dataset from {filepath_col} and {label_col}")
        self.train_set = df.copy()
        # Remember the schema of the training frame for later merges.
        self.filepath_col = filepath_col
        self.label_col = label_col

        logger.info("Creating dataloaders")
        self.dls = ImageDataLoaders.from_df(
            df,
            path=".",
            valid_pct=valid_pct,
            fn_col=filepath_col,
            label_col=label_col,
            bs=batch_size,
            item_tfms=Resize(image_size),
            batch_tfms=aug_transforms(size=image_size, min_scale=0.75),
        )
        logger.info("Creating learner")
        self.learn = vision_learner(self.dls, self.model, metrics=accuracy).to_fp16()
        self.class_names = self.dls.vocab
        logger.info("Done. Ready to train.")

    def lr_find(self) -> None:
        """Run the learner's LR finder and keep the suggestions on ``self.lrs``."""
        logger.info("Finding optimal learning rate")
        self.lrs = self.learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
        logger.info(f"Optimal learning rate: {self.lrs.valley}")

    def train(self, epochs: int, lr: float) -> None:
        """Fine-tune the learner for ``epochs`` epochs at learning rate ``lr``."""
        logger.info(f"Training for {epochs} epochs with learning rate: {lr}")
        self.learn.fine_tune(epochs, lr, cbs=[ShowGraphCallback()])

    @staticmethod
    def _top_confidence(preds):
        """Return the highest per-row score of ``preds`` as a NumPy array."""
        return preds.max(dim=1).values.numpy()

    def predict(self, filepaths: list[str], batch_size: int = 16) -> pd.DataFrame:
        """
        Run inference on an unlabeled dataset. Returns a df with filepaths and predicted labels, and confidence scores.

        The result has columns ``filepath``, ``pred_label`` and ``pred_conf``
        and is also stored on ``self.pred_df``.
        """
        logger.info(f"Running inference on {len(filepaths)} samples")
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        preds, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        vocab = self.learn.dls.vocab
        self.pred_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "pred_label": [vocab[i] for i in cls_preds.numpy()],
                # fastai's get_preds already applies the loss function's
                # activation (softmax for the default cross-entropy loss), so
                # `preds` are probabilities. Applying softmax a second time
                # would squash the scores and distort the uncertainty ranking.
                "pred_conf": self._top_confidence(preds),
            }
        )
        return self.pred_df

    def evaluate(
        self,
        df: pd.DataFrame,
        filepath_col: str,
        label_col: str,
        batch_size: int = 16,
    ) -> float:
        """
        Evaluate on a labeled dataset. Returns a score.

        The score is the fraction of rows whose predicted label equals the
        value in ``label_col``. Per-row results are kept on ``self.eval_df``.
        """
        self.eval_set = df.copy()

        filepaths = self.eval_set[filepath_col].tolist()
        labels = self.eval_set[label_col].tolist()
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        _, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        vocab = self.learn.dls.vocab
        self.eval_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "label": labels,
                "pred_label": [vocab[i] for i in cls_preds.numpy()],
            }
        )

        # Named `score` so the imported fastai `accuracy` metric is not shadowed.
        score = float((self.eval_df["label"] == self.eval_df["pred_label"]).mean())
        logger.info(f"Accuracy: {score:.2%}")
        return score

    def sample_uncertain(self, df: pd.DataFrame, num_samples: int) -> pd.DataFrame:
        """
        Sample top `num_samples` low confidence samples. Returns a df with filepaths and predicted labels, and confidence scores.

        ``df`` must contain a ``pred_conf`` column (as produced by ``predict``).
        """
        return df.sort_values(by="pred_conf", ascending=True).head(num_samples)

    @staticmethod
    def _merge_samples(
        train_set: pd.DataFrame,
        new_samples: pd.DataFrame,
        filepath_col: str = "filepath",
    ) -> tuple[pd.DataFrame, int]:
        """Append ``new_samples`` to ``train_set``, de-duplicated on ``filepath_col``.

        When a file path occurs more than once the last occurrence wins, so
        rows from ``new_samples`` override older rows. Returns the merged
        frame (with a fresh 0..n-1 index) and the net change in row count.
        """
        merged = (
            pd.concat([train_set, new_samples])
            .drop_duplicates(subset=[filepath_col], keep="last")
            .reset_index(drop=True)
        )
        return merged, len(merged) - len(train_set)

    def add_to_train_set(self, df: pd.DataFrame) -> None:
        """
        Add samples to the training set.

        ``df`` is expected in the shape returned by ``predict`` /
        ``sample_uncertain``: ``pred_conf`` is dropped if present, and
        ``filepath`` / ``pred_label`` are renamed to the column names given
        to ``load_dataset`` (defaulting to ``filepath`` / ``label``).
        """
        filepath_col = getattr(self, "filepath_col", "filepath")
        label_col = getattr(self, "label_col", "label")

        # errors="ignore": accept frames that never had a confidence column.
        new_train_set = df.drop(columns=["pred_conf"], errors="ignore").rename(
            columns={"filepath": filepath_col, "pred_label": label_col}
        )

        logger.info(f"Adding {len(new_train_set)} samples to training set")
        self.train_set, num_added = self._merge_samples(
            self.train_set, new_train_set, filepath_col
        )

        if num_added <= 0:
            logger.warning("No new samples added to training set")
        elif num_added < len(new_train_set):
            # Fewer rows gained than offered: some paths were already present.
            logger.warning("Some samples were duplicates and removed from training set")
        else:
            logger.info("All new samples added to training set")

        logger.info(f"Training set now has {len(self.train_set)} samples")
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: active-vision
|
3
|
-
Version: 0.0.1
|
3
|
+
Version: 0.0.2
|
4
4
|
Summary: Active learning for edge vision.
|
5
5
|
Requires-Python: >=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
@@ -14,9 +14,11 @@ Requires-Dist: seaborn>=0.13.2
|
|
14
14
|
|
15
15
|
![Python Version](https://img.shields.io/badge/python-3.10%2B-blue?style=for-the-badge)
|
16
16
|
![License](https://img.shields.io/badge/License-Apache%202.0-green.svg?style=for-the-badge)
|
17
|
+
![PyPI](https://img.shields.io/pypi/v/active-vision?style=for-the-badge)
|
18
|
+
![Downloads](https://img.shields.io/pepy/dt/active-vision?style=for-the-badge&logo=pypi&logoColor=white&label=Downloads&color=purple)
|
17
19
|
|
18
20
|
<p align="center">
|
19
|
-
<img src="
|
21
|
+
<img src="https://github.com/dnth/active-vision/blob/main/assets/logo.png" alt="active-vision">
|
20
22
|
</p>
|
21
23
|
|
22
24
|
Active learning at the edge for computer vision.
|
@@ -44,43 +46,37 @@ cd active-vision
|
|
44
46
|
pip install -e .
|
45
47
|
```
|
46
48
|
|
47
|
-
## Usage
|
49
|
+
## Usage
|
50
|
+
See the [notebook](./nbs/end-to-end.ipynb) for a complete example.
|
48
51
|
|
49
52
|
```python
|
50
|
-
|
53
|
+
from active_vision import ActiveLearner
|
54
|
+
import pandas as pd
|
51
55
|
|
52
|
-
#
|
53
|
-
|
56
|
+
# Create an active learner instance with a model
|
57
|
+
al = ActiveLearner("resnet18")
|
54
58
|
|
55
|
-
# Load
|
56
|
-
|
57
|
-
|
58
|
-
# Inital sampling
|
59
|
-
dataset = av.initial_sampling(dataset, n_samples=10)
|
59
|
+
# Load the dataset into the active learner
|
60
|
+
train_df = pd.read_parquet("training_samples.parquet")
|
61
|
+
al.load_dataset(train_df, "filepath", "label")
|
60
62
|
|
61
63
|
# Train the model
|
62
|
-
|
63
|
-
|
64
|
-
# Save the model
|
65
|
-
model.save()
|
66
|
-
|
67
|
-
# Evaluate the model
|
68
|
-
model.evaluate(df)
|
64
|
+
al.train(epochs=3, lr=1e-3)
|
69
65
|
|
70
|
-
#
|
71
|
-
|
66
|
+
# Load evaluation data
|
67
|
+
eval_df = pd.read_parquet("evaluation_samples.parquet")
|
72
68
|
|
73
|
-
#
|
74
|
-
|
69
|
+
# Evaluate the model on a labeled evaluation set
|
70
|
+
accuracy = al.evaluate(eval_df, "filepath", "label")
|
75
71
|
|
76
|
-
#
|
77
|
-
|
72
|
+
# Get predictions from an unlabeled set
|
73
|
+
pred_df = al.predict(filepaths)
|
78
74
|
|
79
|
-
#
|
80
|
-
|
75
|
+
# Sample low confidence predictions
|
76
|
+
uncertain_df = al.sample_uncertain(pred_df, num_samples=10)
|
81
77
|
|
82
|
-
#
|
83
|
-
|
78
|
+
# Add newly labeled data to training set
|
79
|
+
al.add_to_train_set(uncertain_df)
|
84
80
|
```
|
85
81
|
|
86
82
|
## Workflow
|
@@ -0,0 +1,7 @@
|
|
1
|
+
active_vision/__init__.py,sha256=5VE_DRQ_Rgbo7NlPh3-rP2pUClK48jGxPqAcptBscZ8,43
|
2
|
+
active_vision/core.py,sha256=RBVabC350wucYl7KJgIp3fc1pS9pxtG14iDb-ZyBJxI,5262
|
3
|
+
active_vision-0.0.2.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
4
|
+
active_vision-0.0.2.dist-info/METADATA,sha256=7_eqZJnGeIPjb4LLZ-Bqu1AMJ_h77_0bNRyS_COEv5w,8350
|
5
|
+
active_vision-0.0.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
6
|
+
active_vision-0.0.2.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
|
7
|
+
active_vision-0.0.2.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
active_vision/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
|
2
|
-
active_vision-0.0.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
3
|
-
active_vision-0.0.1.dist-info/METADATA,sha256=lPOTTVSPAaX3Rn9Q1ci_jgoQOC-HFpQIyTNqrouOYEs,7936
|
4
|
-
active_vision-0.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
5
|
-
active_vision-0.0.1.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
|
6
|
-
active_vision-0.0.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|