active-vision 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- active_vision/__init__.py +3 -1
- active_vision/core.py +149 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/METADATA +24 -28
- active_vision-0.0.2.dist-info/RECORD +7 -0
- active_vision-0.0.1.dist-info/RECORD +0 -6
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/LICENSE +0 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/WHEEL +0 -0
- {active_vision-0.0.1.dist-info → active_vision-0.0.2.dist-info}/top_level.txt +0 -0
active_vision/__init__.py
CHANGED
active_vision/core.py
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
import pandas as pd
|
2
|
+
from loguru import logger
|
3
|
+
from fastai.vision.models import resnet18, resnet34
|
4
|
+
from fastai.callback.all import ShowGraphCallback
|
5
|
+
from fastai.vision.all import (
|
6
|
+
ImageDataLoaders,
|
7
|
+
aug_transforms,
|
8
|
+
Resize,
|
9
|
+
vision_learner,
|
10
|
+
accuracy,
|
11
|
+
valley,
|
12
|
+
slide,
|
13
|
+
minimum,
|
14
|
+
steep,
|
15
|
+
)
|
16
|
+
import torch
|
17
|
+
import torch.nn.functional as F
|
18
|
+
|
19
|
+
import warnings
|
20
|
+
|
21
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
22
|
+
|
23
|
+
|
24
|
+
class ActiveLearner:
    """Small active-learning loop built on top of a fastai vision learner.

    Intended call order: ``load_dataset`` -> (``lr_find``) -> ``train`` ->
    ``predict`` -> ``sample_uncertain`` -> ``add_to_train_set``.

    Attributes set by the methods below:
        model: architecture callable chosen by name in ``__init__``.
        train_set: copy of the labeled training DataFrame.
        filepath_col / label_col: column names given to ``load_dataset``.
        dls / learn / class_names: fastai dataloaders, learner and vocab.
        lrs: learning-rate suggestions from the last ``lr_find`` call.
        pred_df / eval_df: results of the last ``predict`` / ``evaluate``.
    """

    def __init__(self, model_name: str):
        """Resolve ``model_name`` to an architecture.

        No data is loaded and no learner is built here; see ``load_dataset``.

        Raises:
            ValueError: if ``model_name`` is not a supported architecture.
        """
        self.model = self.load_model(model_name)

    def load_model(self, model_name: str):
        """Return the architecture callable registered under ``model_name``.

        Supported names are ``"resnet18"`` and ``"resnet34"``.

        Raises:
            ValueError: if ``model_name`` is not one of the supported names.
        """
        models = {"resnet18": resnet18, "resnet34": resnet34}
        logger.info(f"Loading model {model_name}")
        if model_name not in models:
            logger.error(f"Model {model_name} not found")
            raise ValueError(f"Model {model_name} not found")
        return models[model_name]

    def load_dataset(
        self,
        df: pd.DataFrame,
        filepath_col: str,
        label_col: str,
        valid_pct: float = 0.2,
        batch_size: int = 16,
        image_size: int = 224,
    ):
        """Build dataloaders and a learner from a labeled DataFrame.

        Args:
            df: one row per image; a copy is kept as ``self.train_set``.
            filepath_col: column holding the image file paths.
            label_col: column holding the class labels.
            valid_pct: fraction of ``df`` held out for validation.
            batch_size: batch size of the dataloaders.
            image_size: size passed to both the item resize and the batch
                augmentations.
        """
        logger.info(f"Loading dataset from {filepath_col} and {label_col}")
        self.train_set = df.copy()
        # Remember the schema so add_to_train_set can map prediction columns
        # onto the caller's column names instead of assuming fixed ones.
        self.filepath_col = filepath_col
        self.label_col = label_col

        logger.info("Creating dataloaders")
        self.dls = ImageDataLoaders.from_df(
            df,
            path=".",
            valid_pct=valid_pct,
            fn_col=filepath_col,
            label_col=label_col,
            bs=batch_size,
            item_tfms=Resize(image_size),
            batch_tfms=aug_transforms(size=image_size, min_scale=0.75),
        )
        logger.info("Creating learner")
        self.learn = vision_learner(self.dls, self.model, metrics=accuracy).to_fp16()
        self.class_names = self.dls.vocab
        logger.info("Done. Ready to train.")

    def lr_find(self):
        """Run the learning-rate finder and store its suggestions in ``self.lrs``.

        All four suggestion functions are requested; only the ``valley``
        suggestion is logged.
        """
        logger.info("Finding optimal learning rate")
        self.lrs = self.learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
        logger.info(f"Optimal learning rate: {self.lrs.valley}")

    def train(self, epochs: int, lr: float):
        """Fine-tune the learner for ``epochs`` epochs at learning rate ``lr``."""
        logger.info(f"Training for {epochs} epochs with learning rate: {lr}")
        self.learn.fine_tune(epochs, lr, cbs=[ShowGraphCallback()])

    def predict(self, filepaths: list[str], batch_size: int = 16):
        """Run inference on unlabeled images.

        Args:
            filepaths: image paths to score.
            batch_size: batch size of the inference dataloader.

        Returns:
            DataFrame with columns ``filepath``, ``pred_label`` and
            ``pred_conf`` (probability of the predicted class). The same
            frame is stored as ``self.pred_df``.
        """
        logger.info(f"Running inference on {len(filepaths)} samples")

        # len() rather than truthiness: callers may pass an array-like whose
        # truth value is ambiguous. Nothing to score -> empty, well-formed frame.
        if len(filepaths) == 0:
            self.pred_df = pd.DataFrame(
                columns=["filepath", "pred_label", "pred_conf"]
            )
            return self.pred_df

        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        preds, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        # get_preds already applies the loss function's activation (softmax
        # for the learner's default cross-entropy loss), so `preds` holds
        # probabilities. Applying softmax again would squash every confidence
        # towards 1/n_classes and distort the uncertainty ranking.
        confidences = torch.max(preds, dim=1)[0]

        vocab = self.learn.dls.vocab
        self.pred_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "pred_label": [vocab[i] for i in cls_preds.numpy()],
                "pred_conf": confidences.numpy(),
            }
        )
        return self.pred_df

    def evaluate(
        self,
        df: pd.DataFrame,
        filepath_col: str,
        label_col: str,
        batch_size: int = 16,
    ):
        """Score the learner on a labeled DataFrame.

        Args:
            df: labeled evaluation data; a copy is kept as ``self.eval_set``.
            filepath_col: column holding the image file paths.
            label_col: column holding the ground-truth labels.
            batch_size: batch size of the inference dataloader.

        Returns:
            Accuracy as a float in ``[0, 1]``. Per-sample results are stored
            in ``self.eval_df`` (columns ``filepath``, ``label``,
            ``pred_label``).
        """
        self.eval_set = df.copy()

        filepaths = self.eval_set[filepath_col].tolist()
        labels = self.eval_set[label_col].tolist()
        test_dl = self.dls.test_dl(filepaths, bs=batch_size)
        _, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)

        vocab = self.learn.dls.vocab
        self.eval_df = pd.DataFrame(
            {
                "filepath": filepaths,
                "label": labels,
                "pred_label": [vocab[i] for i in cls_preds.numpy()],
            }
        )

        # Named `score` so the module-level fastai `accuracy` metric is not
        # shadowed inside this method.
        score = float((self.eval_df["label"] == self.eval_df["pred_label"]).mean())
        logger.info(f"Accuracy: {score:.2%}")
        return score

    def sample_uncertain(self, df: pd.DataFrame, num_samples: int):
        """Return the ``num_samples`` rows of ``df`` with the lowest ``pred_conf``.

        Args:
            df: prediction frame as returned by ``predict``; it is not modified.
            num_samples: maximum number of rows to return.

        Returns:
            A new DataFrame sorted by ascending ``pred_conf`` with the original
            index labels preserved.
        """
        # Stable sort keeps equally-confident rows in their input order, so
        # repeated calls select the same samples.
        ranked = df.sort_values(by="pred_conf", ascending=True, kind="stable")
        return ranked.head(num_samples)

    def add_to_train_set(self, df: pd.DataFrame):
        """Merge predicted samples into ``self.train_set``.

        ``df`` is expected in the format produced by ``predict`` /
        ``sample_uncertain``: ``pred_conf`` is dropped if present, and
        ``filepath`` / ``pred_label`` are renamed to the column names given to
        ``load_dataset``. Rows are de-duplicated on the file path, keeping the
        most recently added row.

        Only ``self.train_set`` is updated; ``self.dls`` and ``self.learn`` are
        left untouched, so call ``load_dataset`` again with ``self.train_set``
        to train on the enlarged set.
        """
        # Fall back to the historical fixed names if train_set was assigned
        # without going through load_dataset.
        filepath_col = getattr(self, "filepath_col", "filepath")
        label_col = getattr(self, "label_col", "label")

        # errors="ignore": frames whose confidence column was already removed
        # are accepted as well.
        new_train_set = df.drop(columns=["pred_conf"], errors="ignore").rename(
            columns={"filepath": filepath_col, "pred_label": label_col}
        )

        len_old = len(self.train_set)
        num_new = len(new_train_set)

        logger.info(f"Adding {num_new} samples to training set")
        combined = pd.concat([self.train_set, new_train_set])
        combined = combined.drop_duplicates(subset=[filepath_col], keep="last")
        self.train_set = combined.reset_index(drop=True)

        num_added = len(self.train_set) - len_old
        if num_added == 0:
            logger.warning("No new samples added to training set")
        elif num_added < num_new:
            # Fewer rows gained than offered: some were duplicates of
            # existing (or other incoming) file paths.
            logger.warning("Some samples were duplicates and removed from training set")
        else:
            logger.info("All new samples added to training set")
        logger.info(f"Training set now has {len(self.train_set)} samples")
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: active-vision
|
3
|
-
Version: 0.0.1
|
3
|
+
Version: 0.0.2
|
4
4
|
Summary: Active learning for edge vision.
|
5
5
|
Requires-Python: >=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
@@ -14,9 +14,11 @@ Requires-Dist: seaborn>=0.13.2
|
|
14
14
|
|
15
15
|

|
16
16
|

|
17
|
+

|
18
|
+

|
17
19
|
|
18
20
|
<p align="center">
|
19
|
-
<img src="
|
21
|
+
<img src="https://github.com/dnth/active-vision/blob/main/assets/logo.png" alt="active-vision">
|
20
22
|
</p>
|
21
23
|
|
22
24
|
Active learning at the edge for computer vision.
|
@@ -44,43 +46,37 @@ cd active-vision
|
|
44
46
|
pip install -e .
|
45
47
|
```
|
46
48
|
|
47
|
-
## Usage
|
49
|
+
## Usage
|
50
|
+
See the [notebook](./nbs/end-to-end.ipynb) for a complete example.
|
48
51
|
|
49
52
|
```python
|
50
|
-
|
53
|
+
from active_vision import ActiveLearner
|
54
|
+
import pandas as pd
|
51
55
|
|
52
|
-
#
|
53
|
-
|
56
|
+
# Create an active learner instance with a model
|
57
|
+
al = ActiveLearner("resnet18")
|
54
58
|
|
55
|
-
# Load
|
56
|
-
|
57
|
-
|
58
|
-
# Inital sampling
|
59
|
-
dataset = av.initial_sampling(dataset, n_samples=10)
|
59
|
+
# Load the dataset into the active learner
|
60
|
+
train_df = pd.read_parquet("training_samples.parquet")
|
61
|
+
al.load_dataset(train_df, "filepath", "label")
|
60
62
|
|
61
63
|
# Train the model
|
62
|
-
|
63
|
-
|
64
|
-
# Save the model
|
65
|
-
model.save()
|
66
|
-
|
67
|
-
# Evaluate the model
|
68
|
-
model.evaluate(df)
|
64
|
+
al.train(epochs=3, lr=1e-3)
|
69
65
|
|
70
|
-
#
|
71
|
-
|
66
|
+
# Load evaluation data
|
67
|
+
eval_df = pd.read_parquet("evaluation_samples.parquet")
|
72
68
|
|
73
|
-
#
|
74
|
-
|
69
|
+
# Evaluate the model on a labeled evaluation set
|
70
|
+
accuracy = al.evaluate(eval_df, "filepath", "label")
|
75
71
|
|
76
|
-
#
|
77
|
-
|
72
|
+
# Get predictions from an unlabeled set
|
73
|
+
pred_df = al.predict(filepaths)
|
78
74
|
|
79
|
-
#
|
80
|
-
|
75
|
+
# Sample low confidence predictions
|
76
|
+
uncertain_df = al.sample_uncertain(pred_df, num_samples=10)
|
81
77
|
|
82
|
-
#
|
83
|
-
|
78
|
+
# Add newly labeled data to training set
|
79
|
+
al.add_to_train_set(uncertain_df)
|
84
80
|
```
|
85
81
|
|
86
82
|
## Workflow
|
@@ -0,0 +1,7 @@
|
|
1
|
+
active_vision/__init__.py,sha256=5VE_DRQ_Rgbo7NlPh3-rP2pUClK48jGxPqAcptBscZ8,43
|
2
|
+
active_vision/core.py,sha256=RBVabC350wucYl7KJgIp3fc1pS9pxtG14iDb-ZyBJxI,5262
|
3
|
+
active_vision-0.0.2.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
4
|
+
active_vision-0.0.2.dist-info/METADATA,sha256=7_eqZJnGeIPjb4LLZ-Bqu1AMJ_h77_0bNRyS_COEv5w,8350
|
5
|
+
active_vision-0.0.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
6
|
+
active_vision-0.0.2.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
|
7
|
+
active_vision-0.0.2.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
active_vision/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
|
2
|
-
active_vision-0.0.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
3
|
-
active_vision-0.0.1.dist-info/METADATA,sha256=lPOTTVSPAaX3Rn9Q1ci_jgoQOC-HFpQIyTNqrouOYEs,7936
|
4
|
-
active_vision-0.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
5
|
-
active_vision-0.0.1.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
|
6
|
-
active_vision-0.0.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|