active-vision 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- active_vision/__init__.py +3 -1
 - active_vision/core.py +291 -0
 - {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/METADATA +55 -36
 - active_vision-0.0.3.dist-info/RECORD +7 -0
 - active_vision-0.0.1.dist-info/RECORD +0 -6
 - {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/LICENSE +0 -0
 - {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/WHEEL +0 -0
 - {active_vision-0.0.1.dist-info → active_vision-0.0.3.dist-info}/top_level.txt +0 -0
 
    
        active_vision/__init__.py
    CHANGED
    
    
    
        active_vision/core.py
    ADDED
    
    | 
         @@ -0,0 +1,291 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 2 
     | 
    
         
            +
            from loguru import logger
         
     | 
| 
      
 3 
     | 
    
         
            +
            from fastai.vision.models import resnet18, resnet34
         
     | 
| 
      
 4 
     | 
    
         
            +
            from fastai.callback.all import ShowGraphCallback
         
     | 
| 
      
 5 
     | 
    
         
            +
            from fastai.vision.all import (
         
     | 
| 
      
 6 
     | 
    
         
            +
                ImageDataLoaders,
         
     | 
| 
      
 7 
     | 
    
         
            +
                aug_transforms,
         
     | 
| 
      
 8 
     | 
    
         
            +
                Resize,
         
     | 
| 
      
 9 
     | 
    
         
            +
                vision_learner,
         
     | 
| 
      
 10 
     | 
    
         
            +
                accuracy,
         
     | 
| 
      
 11 
     | 
    
         
            +
                valley,
         
     | 
| 
      
 12 
     | 
    
         
            +
                slide,
         
     | 
| 
      
 13 
     | 
    
         
            +
                minimum,
         
     | 
| 
      
 14 
     | 
    
         
            +
                steep,
         
     | 
| 
      
 15 
     | 
    
         
            +
            )
         
     | 
| 
      
 16 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 17 
     | 
    
         
            +
            import torch.nn.functional as F
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            # Import-time side effect: installs a process-wide filter that hides every
            # FutureWarning, including ones raised by code outside this package.
            # NOTE(review): presumably meant to quiet noise from the ML libraries
            # imported above — confirm, and consider scoping it more narrowly.
            warnings.filterwarnings("ignore", category=FutureWarning)
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
            class ActiveLearner:
         
     | 
| 
      
 25 
     | 
    
         
            +
                def __init__(self, model_name: str):
                    """
                    Create a learner bound to the architecture registered as `model_name`.

                    The lookup is delegated to `load_model`; its result is kept on
                    `self.model`.
                    """
                    architecture = self.load_model(model_name)
                    self.model = architecture
         
     | 
| 
      
 27 
     | 
    
         
            +
             
     | 
| 
      
 28 
     | 
    
         
            +
                def load_model(self, model_name: str):
                    """
                    Look up a supported architecture by name.

                    Args:
                        model_name: Either "resnet18" or "resnet34".

                    Returns:
                        The matching fastai model constructor.

                    Raises:
                        ValueError: If `model_name` is not one of the supported names.
                    """
                    registry = {"resnet18": resnet18, "resnet34": resnet34}
                    logger.info(f"Loading model {model_name}")

                    architecture = registry.get(model_name)
                    if architecture is None:
                        # Log and raise with the same text so both channels agree.
                        message = f"Model {model_name} not found"
                        logger.error(message)
                        raise ValueError(message)
                    return architecture
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
                def load_dataset(
                    self,
                    df: pd.DataFrame,
                    filepath_col: str,
                    label_col: str,
                    valid_pct: float = 0.2,
                    batch_size: int = 16,
                    image_size: int = 224,
                ):
                    """
                    Build the dataloaders and the learner from a labeled DataFrame.

                    Args:
                        df: Frame holding image paths and labels; a copy is kept on
                            `self.train_set`.
                        filepath_col: Column of `df` with the image paths.
                        label_col: Column of `df` with the labels.
                        valid_pct: Forwarded to `ImageDataLoaders.from_df` as `valid_pct`.
                        batch_size: Forwarded to `ImageDataLoaders.from_df` as `bs`.
                        image_size: Size used for both the per-item `Resize` and the
                            batch augmentations.

                    Sets `self.train_set`, `self.dls`, `self.learn` and `self.class_names`.
                    """
                    logger.info(f"Loading dataset from {filepath_col} and {label_col}")
                    self.train_set = df.copy()

                    logger.info("Creating dataloaders")
                    per_item_tfms = Resize(image_size)
                    per_batch_tfms = aug_transforms(size=image_size, min_scale=0.75)
                    loaders = ImageDataLoaders.from_df(
                        df,
                        path=".",
                        valid_pct=valid_pct,
                        fn_col=filepath_col,
                        label_col=label_col,
                        bs=batch_size,
                        item_tfms=per_item_tfms,
                        batch_tfms=per_batch_tfms,
                    )
                    self.dls = loaders

                    logger.info("Creating learner")
                    learner = vision_learner(loaders, self.model, metrics=accuracy)
                    self.learn = learner.to_fp16()
                    self.class_names = loaders.vocab
                    logger.info("Done. Ready to train.")
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
                def lr_find(self):
                    """
                    Run the learner's learning-rate finder and log the `valley` suggestion.

                    All suggestions (minimum, steep, valley, slide) are kept on `self.lrs`.
                    """
                    logger.info("Finding optimal learning rate")
                    suggesters = (minimum, steep, valley, slide)
                    self.lrs = self.learn.lr_find(suggest_funcs=suggesters)
                    logger.info(f"Optimal learning rate: {self.lrs.valley}")
         
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
                def train(self, epochs: int, lr: float):
                    """
                    Fine-tune the learner.

                    Args:
                        epochs: Passed to `fine_tune` as its first positional argument.
                        lr: Passed to `fine_tune` as its second positional argument.

                    A `ShowGraphCallback` is attached for the run.
                    """
                    logger.info(f"Training for {epochs} epochs with learning rate: {lr}")
                    callbacks = [ShowGraphCallback()]
                    self.learn.fine_tune(epochs, lr, cbs=callbacks)
         
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
                def predict(self, filepaths: list[str], batch_size: int = 16):
                    """
                    Run inference on an unlabeled dataset.

                    Args:
                        filepaths: Paths of the images to score.
                        batch_size: Batch size of the inference dataloader.

                    Returns:
                        A DataFrame with one row per input and the columns `filepath`,
                        `pred_label` (vocab entry of the decoded class) and `pred_conf`
                        (highest class probability). The same frame is kept on
                        `self.pred_df`.
                    """
                    logger.info(f"Running inference on {len(filepaths)} samples")
                    test_dl = self.dls.test_dl(filepaths, bs=batch_size)

                    # Learner.get_preds applies the loss function's activation before
                    # returning (softmax for the default cross-entropy loss), so the first
                    # element already holds probabilities. Applying softmax again would
                    # flatten every score toward 1 / num_classes.
                    probs, _, cls_preds = self.learn.get_preds(dl=test_dl, with_decoded=True)
                    vocab = self.learn.dls.vocab

                    self.pred_df = pd.DataFrame(
                        {
                            "filepath": filepaths,
                            "pred_label": [vocab[i] for i in cls_preds.numpy()],
                            "pred_conf": probs.max(dim=1).values.numpy(),
                        }
                    )
                    return self.pred_df
         
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
                def evaluate(
                    self, df: pd.DataFrame, filepath_col: str, label_col: str, batch_size: int = 16
                ):
                    """
                    Score the learner on a labeled dataset.

                    Args:
                        df: Labeled frame; a copy is kept on `self.eval_set`.
                        filepath_col: Column of `df` with the image paths.
                        label_col: Column of `df` with the ground-truth labels.
                        batch_size: Batch size of the evaluation dataloader.

                    Returns:
                        The fraction of rows whose predicted label equals the given label,
                        as a float. Per-row results are kept on `self.eval_df`.
                    """
                    self.eval_set = df.copy()
                    paths = self.eval_set[filepath_col].tolist()
                    truth = self.eval_set[label_col].tolist()

                    eval_dl = self.dls.test_dl(paths, bs=batch_size)
                    _, _, decoded = self.learn.get_preds(dl=eval_dl, with_decoded=True)
                    predicted = [self.learn.dls.vocab[i] for i in decoded.numpy()]

                    self.eval_df = pd.DataFrame(
                        {"filepath": paths, "label": truth, "pred_label": predicted}
                    )

                    # Local name kept distinct from the imported fastai `accuracy` metric.
                    matches = self.eval_df["label"] == self.eval_df["pred_label"]
                    score = float(matches.mean())
                    logger.info(f"Accuracy: {score:.2%}")
                    return score
         
     | 
| 
      
 114 
     | 
    
         
            +
             
     | 
| 
      
 115 
     | 
    
         
            +
                def sample_uncertain(self, df: pd.DataFrame, num_samples: int):
                    """
                    Pick the `num_samples` rows with the lowest `pred_conf`.

                    Args:
                        df: Frame with a `pred_conf` column.
                        num_samples: Maximum number of rows to return.

                    Returns:
                        The selected rows of `df`, ordered by ascending `pred_conf`.
                    """
                    logger.info(f"Getting top {num_samples} low confidence samples")
                    by_confidence = df.sort_values(by="pred_conf", ascending=True)
                    return by_confidence.head(num_samples)
         
     | 
| 
      
 122 
     | 
    
         
            +
             
     | 
| 
      
 123 
     | 
    
         
            +
                def label(self, df: pd.DataFrame, output_filename: str = "labeled"):
                    """
                    Launch a Gradio labeling interface for the samples in `df`.

                    Every submitted annotation is appended as a `filepath,label` row to
                    `{output_filename}.csv`. Pressing "Finish Labeling" reads that CSV,
                    keeps the last label recorded for each filepath and writes the result
                    to `{output_filename}.parquet`.

                    Keyboard shortcuts (ignored while a text input or textarea has focus):
                    W = submit, D = next, A = previous.

                    Args:
                        df: Frame with a "filepath" column listing the files to label.
                        output_filename: Path stem (without extension) for the CSV and
                            parquet outputs.

                    Returns:
                        None. When `df` has no rows a warning is logged and no interface
                        is launched.

                    The category choices are read from `self.class_names`.
                    """
                    import csv

                    import gradio as gr

                    shortcut_js = """
                    <script>
                    function shortcuts(e) {
                        // Only block shortcuts if we're in a text input or textarea
                        if (e.target.tagName.toLowerCase() === "textarea" || 
                            (e.target.tagName.toLowerCase() === "input" && e.target.type.toLowerCase() === "text")) {
                            return;
                        }
                        
                        if (e.key.toLowerCase() == "w") {
                            document.getElementById("submit_btn").click();
                        } else if (e.key.toLowerCase() == "d") {
                            document.getElementById("next_btn").click();
                        } else if (e.key.toLowerCase() == "a") {
                            document.getElementById("back_btn").click();
                        }
                    }
                    document.addEventListener('keypress', shortcuts, false);
                    </script>
                    """

                    filepaths = df["filepath"].tolist()
                    if not filepaths:
                        # The widgets below are seeded with filepaths[0]; bail out instead
                        # of failing with an IndexError.
                        logger.warning("No samples to label")
                        return

                    logger.info(f"Launching labeling interface for {len(df)} samples")

                    csv_path = f"{output_filename}.csv"
                    parquet_path = f"{output_filename}.parquet"
                    last_idx = len(filepaths) - 1

                    def view(idx):
                        # Values for the (filename, image, current_index, progress) outputs.
                        return filepaths[idx], filepaths[idx], idx, idx

                    with gr.Blocks(head=shortcut_js) as demo:
                        current_index = gr.State(value=0)

                        filename = gr.Textbox(
                            label="Filename", value=filepaths[0], interactive=False
                        )

                        image = gr.Image(
                            type="filepath", label="Image", value=filepaths[0], height=500
                        )
                        category = gr.Radio(choices=self.class_names, label="Select Category")

                        with gr.Row():
                            back_btn = gr.Button("← Previous (A)", elem_id="back_btn")
                            submit_btn = gr.Button(
                                "Submit (W)",
                                variant="primary",
                                elem_id="submit_btn",
                                interactive=False,
                            )
                            next_btn = gr.Button("Next → (D)", elem_id="next_btn")

                        progress = gr.Slider(
                            minimum=0,
                            maximum=last_idx,
                            value=0,
                            label="Progress",
                            interactive=False,
                        )

                        finish_btn = gr.Button("Finish Labeling", variant="primary")

                        def update_submit_btn(choice):
                            # Submit is only clickable once a category is selected.
                            return gr.Button(interactive=choice is not None)

                        category.change(
                            fn=update_submit_btn, inputs=[category], outputs=[submit_btn]
                        )

                        def navigate(current_idx, direction):
                            # Move by `direction`, staying put at either end of the list.
                            next_idx = current_idx + direction
                            if 0 <= next_idx <= last_idx:
                                return view(next_idx)
                            return view(current_idx)

                        def save_and_next(current_idx, selected_category):
                            if selected_category is None:
                                return view(current_idx)

                            # csv.writer quotes fields containing commas, quotes or
                            # newlines, so such paths/labels survive the later read_csv.
                            # UTF-8 matches read_csv's default decoding.
                            with open(csv_path, "a", newline="", encoding="utf-8") as f:
                                csv.writer(f, lineterminator="\n").writerow(
                                    [filepaths[current_idx], selected_category]
                                )

                            # Advance unless the last sample was just labeled.
                            return view(min(current_idx + 1, last_idx))

                        def convert_csv_to_parquet():
                            try:
                                labeled = pd.read_csv(csv_path, header=None)
                                labeled.columns = ["filepath", "label"]
                                # A file may be labeled more than once; the latest wins.
                                labeled = labeled.drop_duplicates(
                                    subset=["filepath"], keep="last"
                                )
                                labeled.to_parquet(parquet_path)
                                gr.Info(f"Annotation saved to {output_filename}.parquet")
                            except Exception as e:
                                # Best effort: keep the UI alive and leave the CSV intact.
                                logger.error(e)

                        back_btn.click(
                            fn=lambda idx: navigate(idx, -1),
                            inputs=[current_index],
                            outputs=[filename, image, current_index, progress],
                        )

                        next_btn.click(
                            fn=lambda idx: navigate(idx, 1),
                            inputs=[current_index],
                            outputs=[filename, image, current_index, progress],
                        )

                        submit_btn.click(
                            fn=save_and_next,
                            inputs=[current_index, category],
                            outputs=[filename, image, current_index, progress],
                        )

                        finish_btn.click(fn=convert_csv_to_parquet)

                    demo.launch(height=1000)
         
     | 
| 
      
 261 
     | 
    
         
            +
             
     | 
| 
      
 262 
     | 
    
         
            +
                def add_to_train_set(self, df: pd.DataFrame, output_filename: str):
                    """
                    Merge new samples into the training set and persist the result.

                    The rows of ``df`` are appended after the rows already held in
                    ``self.train_set``. Rows sharing a ``filepath`` value are then
                    collapsed to the last occurrence, so a row coming from ``df``
                    replaces an existing row for the same file. The index is rebuilt
                    from zero and the merged frame is written to
                    ``<output_filename>.parquet``.

                    Args:
                        df: Samples to add. Must contain a ``filepath`` column, which
                            is the de-duplication key.
                        output_filename: Destination path without the ``.parquet``
                            suffix; the suffix is appended here.
                    """
                    incoming = df.copy()
                    parquet_path = f"{output_filename}.parquet"

                    logger.info(f"Adding {len(incoming)} samples to training set")

                    # Existing rows go first so that keep="last" favours the new rows.
                    self.train_set = pd.concat([self.train_set, incoming])
                    self.train_set = self.train_set.drop_duplicates(
                        subset=["filepath"], keep="last"
                    ).reset_index(drop=True)

                    self.train_set.to_parquet(parquet_path)
                    logger.info(f"Saved training set to {parquet_path}")
         
     | 
| 
         @@ -1,12 +1,13 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            Metadata-Version: 2.2
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: active-vision
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.0. 
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.0.3
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: Active learning for edge vision.
         
     | 
| 
       5 
5 
     | 
    
         
             
            Requires-Python: >=3.10
         
     | 
| 
       6 
6 
     | 
    
         
             
            Description-Content-Type: text/markdown
         
     | 
| 
       7 
7 
     | 
    
         
             
            License-File: LICENSE
         
     | 
| 
       8 
8 
     | 
    
         
             
            Requires-Dist: datasets>=3.2.0
         
     | 
| 
       9 
9 
     | 
    
         
             
            Requires-Dist: fastai>=2.7.18
         
     | 
| 
      
 10 
     | 
    
         
            +
            Requires-Dist: gradio>=5.12.0
         
     | 
| 
       10 
11 
     | 
    
         
             
            Requires-Dist: ipykernel>=6.29.5
         
     | 
| 
       11 
12 
     | 
    
         
             
            Requires-Dist: ipywidgets>=8.1.5
         
     | 
| 
       12 
13 
     | 
    
         
             
            Requires-Dist: loguru>=0.7.3
         
     | 
| 
         @@ -14,75 +15,93 @@ Requires-Dist: seaborn>=0.13.2 
     | 
|
| 
       14 
15 
     | 
    
         | 
| 
       15 
16 
     | 
    
         
             
            
         
     | 
| 
       16 
17 
     | 
    
         
             
            
         
     | 
| 
      
 18 
     | 
    
         
            +
            [](https://pypi.org/project/active-vision/)
         
     | 
| 
      
 19 
     | 
    
         
            +
               
         
     | 
| 
       17 
20 
     | 
    
         | 
| 
       18 
21 
     | 
    
         
             
            <p align="center">
         
     | 
| 
       19 
     | 
    
         
            -
              <img src=" 
     | 
| 
      
 22 
     | 
    
         
            +
              <img src="https://raw.githubusercontent.com/dnth/active-vision/main/assets/logo.png" alt="active-vision">
         
     | 
| 
       20 
23 
     | 
    
         
             
            </p>
         
     | 
| 
       21 
24 
     | 
    
         | 
| 
       22 
25 
     | 
    
         
             
            Active learning at the edge for computer vision.
         
     | 
| 
       23 
26 
     | 
    
         | 
| 
       24 
     | 
    
         
            -
            The goal of this project is to create a framework for active learning  
     | 
| 
      
 27 
     | 
    
         
            +
            The goal of this project is to create a framework for the active learning loop for computer vision deployed on edge devices. 
         
     | 
| 
       25 
28 
     | 
    
         | 
| 
       26 
     | 
    
         
            -
            ##  
     | 
| 
      
 29 
     | 
    
         
            +
            ## Installation
         
     | 
| 
      
 30 
     | 
    
         
            +
            I recommend using [uv](https://docs.astral.sh/uv/) to set up a virtual environment and install the package. You can also use any other virtual environment tool of your choice.
         
     | 
| 
       27 
31 
     | 
    
         | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
            - User interface: streamlit
         
     | 
| 
       30 
     | 
    
         
            -
            - Database: sqlite
         
     | 
| 
       31 
     | 
    
         
            -
            - Experiment tracking: wandb
         
     | 
| 
      
 32 
     | 
    
         
            +
            If you're using uv:
         
     | 
| 
       32 
33 
     | 
    
         | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
      
 34 
     | 
    
         
            +
            ```bash
         
     | 
| 
      
 35 
     | 
    
         
            +
            uv venv
         
     | 
| 
      
 36 
     | 
    
         
            +
            uv sync
         
     | 
| 
      
 37 
     | 
    
         
            +
            ```
         
     | 
| 
      
 38 
     | 
    
         
            +
            Once the virtual environment is created, you can install the package using pip.
         
     | 
| 
       34 
39 
     | 
    
         | 
| 
       35 
     | 
    
         
            -
            PyPI
         
     | 
| 
      
 40 
     | 
    
         
            +
            Get a release from PyPI
         
     | 
| 
       36 
41 
     | 
    
         
             
            ```bash
         
     | 
| 
       37 
42 
     | 
    
         
             
            pip install active-vision
         
     | 
| 
       38 
43 
     | 
    
         
             
            ```
         
     | 
| 
       39 
44 
     | 
    
         | 
| 
       40 
     | 
    
         
            -
             
     | 
| 
      
 45 
     | 
    
         
            +
            Install from source
         
     | 
| 
       41 
46 
     | 
    
         
             
            ```bash
         
     | 
| 
       42 
47 
     | 
    
         
             
            git clone https://github.com/dnth/active-vision.git
         
     | 
| 
       43 
48 
     | 
    
         
             
            cd active-vision
         
     | 
| 
       44 
49 
     | 
    
         
             
            pip install -e .
         
     | 
| 
       45 
50 
     | 
    
         
             
            ```
         
     | 
| 
       46 
51 
     | 
    
         | 
| 
       47 
     | 
    
         
            -
             
     | 
| 
      
 52 
     | 
    
         
            +
            > [!TIP]
         
     | 
| 
      
 53 
     | 
    
         
            +
            > If you're using uv, add `uv` before the `pip install` command to install into your virtual environment. E.g.:
         
     | 
| 
      
 54 
     | 
    
         
            +
            > ```bash
         
     | 
| 
      
 55 
     | 
    
         
            +
            > uv pip install active-vision
         
     | 
| 
      
 56 
     | 
    
         
            +
            > ```
         
     | 
| 
       48 
57 
     | 
    
         | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
      
 58 
     | 
    
         
            +
            ## Usage
         
     | 
| 
      
 59 
     | 
    
         
            +
            See the [notebook](./nbs/04_relabel_loop.ipynb) for a complete example.
         
     | 
| 
       51 
60 
     | 
    
         | 
| 
       52 
     | 
    
         
            -
             
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
      
 61 
     | 
    
         
            +
            Be sure to prepare 3 datasets:
         
     | 
| 
      
 62 
     | 
    
         
            +
            - train: A dataframe of an existing labeled training dataset.
         
     | 
| 
      
 63 
     | 
    
         
            +
            - unlabeled: A dataframe of unlabeled data which we will sample from using active learning.
         
     | 
| 
      
 64 
     | 
    
         
            +
            - eval: A dataframe of labeled data which we will use to evaluate the performance of the model. (Optional)
         
     | 
| 
       54 
65 
     | 
    
         | 
| 
       55 
     | 
    
         
            -
             
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
      
 66 
     | 
    
         
            +
            ```python
         
     | 
| 
      
 67 
     | 
    
         
            +
            from active_vision import ActiveLearner
         
     | 
| 
      
 68 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 69 
     | 
    
         
            +
             
     | 
| 
      
 70 
     | 
    
         
            +
            # Create an active learner instance with a model
         
     | 
| 
      
 71 
     | 
    
         
            +
            al = ActiveLearner("resnet18")
         
     | 
| 
       57 
72 
     | 
    
         | 
| 
       58 
     | 
    
         
            -
            #  
     | 
| 
       59 
     | 
    
         
            -
             
     | 
| 
      
 73 
     | 
    
         
            +
            # Load dataset 
         
     | 
| 
      
 74 
     | 
    
         
            +
            train_df = pd.read_parquet("training_samples.parquet")
         
     | 
| 
      
 75 
     | 
    
         
            +
            al.load_dataset(train_df, filepath_col="filepath", label_col="label")
         
     | 
| 
       60 
76 
     | 
    
         | 
| 
       61 
     | 
    
         
            -
            # Train  
     | 
| 
       62 
     | 
    
         
            -
             
     | 
| 
      
 77 
     | 
    
         
            +
            # Train model
         
     | 
| 
      
 78 
     | 
    
         
            +
            al.train(epochs=3, lr=1e-3)
         
     | 
| 
       63 
79 
     | 
    
         | 
| 
       64 
     | 
    
         
            -
            #  
     | 
| 
       65 
     | 
    
         
            -
             
     | 
| 
      
 80 
     | 
    
         
            +
            # Evaluate the model on a *labeled* evaluation set
         
     | 
| 
      
 81 
     | 
    
         
            +
            accuracy = al.evaluate(eval_df, filepath_col="filepath", label_col="label")
         
     | 
| 
       66 
82 
     | 
    
         | 
| 
       67 
     | 
    
         
            -
            #  
     | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
      
 83 
     | 
    
         
            +
            # Get predictions from an *unlabeled* set
         
     | 
| 
      
 84 
     | 
    
         
            +
            pred_df = al.predict(filepaths)
         
     | 
| 
       69 
85 
     | 
    
         | 
| 
       70 
     | 
    
         
            -
            #  
     | 
| 
       71 
     | 
    
         
            -
             
     | 
| 
      
 86 
     | 
    
         
            +
            # Sample low confidence predictions from unlabeled set
         
     | 
| 
      
 87 
     | 
    
         
            +
            uncertain_df = al.sample_uncertain(pred_df, num_samples=10)
         
     | 
| 
       72 
88 
     | 
    
         | 
| 
       73 
     | 
    
         
            -
            #  
     | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
      
 89 
     | 
    
         
            +
            # Launch a Gradio UI to label the low confidence samples
         
     | 
| 
      
 90 
     | 
    
         
            +
            al.label(uncertain_df, output_filename="uncertain")
         
     | 
| 
      
 91 
     | 
    
         
            +
            ```
         
     | 
| 
       75 
92 
     | 
    
         | 
| 
       76 
     | 
    
         
            -
             
     | 
| 
       77 
     | 
    
         
            -
            model.random_sampling()
         
     | 
| 
      
 93 
     | 
    
         
            +
            
         
     | 
| 
       78 
94 
     | 
    
         | 
| 
       79 
     | 
    
         
            -
             
     | 
| 
       80 
     | 
    
         
            -
             
     | 
| 
      
 95 
     | 
    
         
            +
            Once complete, the labeled samples will be saved into a new df.
         
     | 
| 
      
 96 
     | 
    
         
            +
            We can now add the newly labeled data to the training set.
         
     | 
| 
       81 
97 
     | 
    
         | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
             
     | 
| 
      
 98 
     | 
    
         
            +
            ```python
         
     | 
| 
      
 99 
     | 
    
         
            +
            # Add newly labeled data to training set and save as a new file active_labeled
         
     | 
| 
      
 100 
     | 
    
         
            +
            al.add_to_train_set(labeled_df, output_filename="active_labeled")
         
     | 
| 
       84 
101 
     | 
    
         
             
            ```
         
     | 
| 
       85 
102 
     | 
    
         | 
| 
      
 103 
     | 
    
         
            +
            Repeat the process until the model is good enough. Use the dataset to train a larger model and deploy.
         
     | 
| 
      
 104 
     | 
    
         
            +
             
     | 
| 
       86 
105 
     | 
    
         
             
            ## Workflow
         
     | 
| 
       87 
106 
     | 
    
         
             
            There are two workflows for active learning at the edge that we can use depending on the availability of labeled data.
         
     | 
| 
       88 
107 
     | 
    
         | 
| 
         @@ -0,0 +1,7 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            active_vision/__init__.py,sha256=hZp8jB284ByY44Q5cdwTt9Zz5n4QWXnz0OexpEE9muk,43
         
     | 
| 
      
 2 
     | 
    
         
            +
            active_vision/core.py,sha256=0aXDI5Gpj0Spk7TSIxJf8aJDDBgZh0-jkWdYyZ1Zric,10713
         
     | 
| 
      
 3 
     | 
    
         
            +
            active_vision-0.0.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
         
     | 
| 
      
 4 
     | 
    
         
            +
            active_vision-0.0.3.dist-info/METADATA,sha256=g629Kn07n4ZXOOX5cW1nPQK1IR9Mm5vW_z7RqxdwKgY,9385
         
     | 
| 
      
 5 
     | 
    
         
            +
            active_vision-0.0.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
         
     | 
| 
      
 6 
     | 
    
         
            +
            active_vision-0.0.3.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
         
     | 
| 
      
 7 
     | 
    
         
            +
            active_vision-0.0.3.dist-info/RECORD,,
         
     | 
| 
         @@ -1,6 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            active_vision/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
         
     | 
| 
       2 
     | 
    
         
            -
            active_vision-0.0.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
         
     | 
| 
       3 
     | 
    
         
            -
            active_vision-0.0.1.dist-info/METADATA,sha256=lPOTTVSPAaX3Rn9Q1ci_jgoQOC-HFpQIyTNqrouOYEs,7936
         
     | 
| 
       4 
     | 
    
         
            -
            active_vision-0.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
         
     | 
| 
       5 
     | 
    
         
            -
            active_vision-0.0.1.dist-info/top_level.txt,sha256=7qUQvccN2UU63z5S9vrgJmqK-8sFGrtpf1e9Z86nihE,14
         
     | 
| 
       6 
     | 
    
         
            -
            active_vision-0.0.1.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |