PictSure 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Examples/example.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ import random
+ from PIL import Image
+ import torch
+ from PictSure import PictSure
+
+ # CONFIG
+ ROOT_DIR = "./BrainTumor_preprocessed/"
+ NUM_CONTEXT_IMAGES = 5
+ IMAGE_SIZE = 224
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+ # Load context/reference images
+ def load_reference_images(path):
+     label_map = {}
+     context_images, context_labels = [], []
+     chosen = []  # filenames used as context images, excluded from the test pool later
+
+     folders = sorted(os.listdir(path))
+     for label, folder in enumerate(folders):
+         folder_path = os.path.join(path, folder)
+         all_images = os.listdir(folder_path)
+         ref_imgs = random.sample(all_images, NUM_CONTEXT_IMAGES)
+
+         for img_name in ref_imgs:
+             img_path = os.path.join(folder_path, img_name)
+             img = Image.open(img_path).convert("RGB")
+             context_images.append(img)
+             context_labels.append(label)
+
+         chosen.extend(ref_imgs)
+         label_map[folder] = label
+
+     return context_images, context_labels, label_map, chosen
+
+ # Pick a single test image from the images that were not used as context
+ def pick_test_image(path, label_map, chosen):
+     all_images = []
+     all_labels = []
+
+     for folder, label in label_map.items():
+         folder_path = os.path.join(path, folder)
+         images = [f for f in os.listdir(folder_path) if f not in chosen]
+         for img_name in images:
+             img_path = os.path.join(folder_path, img_name)
+             all_images.append(img_path)
+             all_labels.append(label)
+
+     if all_images:
+         random_index = random.randint(0, len(all_images) - 1)
+         img_path = all_images[random_index]
+         label = all_labels[random_index]
+         img = Image.open(img_path).convert("RGB")
+         return img, label
+
+ # Pull a pre-trained model from Hugging Face
+ pictsure_model = PictSure.from_pretrained("pictsure/pictsure-vit").to(DEVICE)
+
+ results = []
+ for i in range(200):
+     # Load references and test image
+     context_imgs, context_lbls, label_map, chosen = load_reference_images(ROOT_DIR)
+     test_img, test_lbl = pick_test_image(ROOT_DIR, label_map, chosen)
+     # Predict
+     with torch.no_grad():
+         pictsure_model.set_context_images(context_imgs, context_lbls)
+         pred = pictsure_model.predict(test_img)
+
+     results.append(pred == test_lbl)
+
+ accuracy = sum(results) / len(results) * 100
+ print(f"Accuracy over {len(results)} predictions: {accuracy:.1f}%")
PictSure/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .model_PictSure import PictSure
+ from .model_embeddings import ResNetWrapper, VitNetWrapper
+
+ __all__ = ['PictSure', 'ResNetWrapper', 'VitNetWrapper']
PictSure/cli.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ import click
+ import shutil
+ from .config import PRETRAINED
+ import pkg_resources
+
+ @click.group()
+ def cli():
+     """PictSure command line interface."""
+     pass
+
+ @cli.command()
+ def list_models():
+     """List all available and downloaded models."""
+     click.echo("Available PictSure Models:")
+     click.echo("=" * 50)
+
+     # Get the package directory
+     package_dir = os.path.dirname(pkg_resources.resource_filename('PictSure', '__init__.py'))
+
+     for model_name, model_info in PRETRAINED.items():
+         # Check if model is downloaded using absolute path
+         local_folder = os.path.join(package_dir, 'weights', model_info['name'])
+         weights_path = os.path.join(local_folder, 'weights.pt')
+         is_downloaded = os.path.exists(weights_path)
+
+         # Create status indicator
+         status = "✓ Downloaded" if is_downloaded else "✗ Not downloaded"
+         status_color = "green" if is_downloaded else "red"
+
+         # Print model information
+         click.echo(f"\nModel: {click.style(model_info['name'], bold=True)}")
+         click.echo(f"Status: {click.style(status, fg=status_color)}")
+         click.echo(f"Type: {model_info['embed_model']}")
+         click.echo(f"Resolution: {model_info['resolution']}")
+         click.echo(f"Number of classes: {model_info['num_classes']}")
+         click.echo(f"Transformer heads: {model_info['nheads']}")
+         click.echo(f"Transformer layers: {model_info['nlayer']}")
+         click.echo(f"Model size: {model_info['size']} Million Parameters")
+         click.echo(f"Path: {weights_path}")
+         click.echo("-" * 50)
+
+ @cli.command()
+ @click.argument('model_name', type=click.Choice([info['name'] for info in PRETRAINED.values()]))
+ @click.option('--force', '-f', is_flag=True, help='Skip confirmation prompt')
+ def remove(model_name, force):
+     """Remove the weights of a specific model."""
+     # Get the package directory
+     package_dir = os.path.dirname(pkg_resources.resource_filename('PictSure', '__init__.py'))
+
+     # Find the model info by name
+     model_info = next((info for info in PRETRAINED.values() if info['name'] == model_name), None)
+     if not model_info:
+         click.echo(click.style(f"Model {model_name} not found.", fg='red'))
+         return
+
+     # Construct paths
+     local_folder = os.path.join(package_dir, 'weights', model_info['name'])
+     weights_path = os.path.join(local_folder, 'weights.pt')
+
+     if not os.path.exists(weights_path):
+         click.echo(click.style(f"Model {model_info['name']} is not downloaded.", fg='yellow'))
+         return
+
+     if not force:
+         if not click.confirm(f"Are you sure you want to remove the weights for {click.style(model_info['name'], bold=True)}?"):
+             click.echo("Operation cancelled.")
+             return
+
+     try:
+         # Remove the weights file
+         os.remove(weights_path)
+         # Try to remove the directory if it's empty
+         try:
+             os.rmdir(local_folder)
+         except OSError:
+             pass  # Directory might not be empty, which is fine
+
+         click.echo(click.style(f"Successfully removed weights for {model_info['name']}.", fg='green'))
+     except Exception as e:
+         click.echo(click.style(f"Error removing weights: {str(e)}", fg='red'))
+
+ if __name__ == '__main__':
+     cli()
PictSure/model_PictSure.py ADDED
@@ -0,0 +1,240 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ from .model_embeddings import ResNetWrapper, VitNetWrapper, load_encoder
+ from huggingface_hub import PyTorchModelHubMixin
+
+ from PIL import Image
+
+ class PictSure(
+     nn.Module,
+     PyTorchModelHubMixin
+ ):
+     def __init__(self, embedding, num_classes=10, nheads=8, nlayer=4):
+         super(PictSure, self).__init__()
+         if isinstance(embedding, nn.Module):
+             embedding_layer = embedding
+             if not hasattr(embedding_layer, 'latent_dim'):
+                 raise ValueError("Custom embedding module must have a 'latent_dim' attribute.")
+         elif embedding == 'resnet':
+             embedding_layer = load_encoder()
+         elif embedding == 'vit':
+             embedding_layer = VitNetWrapper(path=None, num_classes=1000)
+         else:
+             raise ValueError("Unsupported embedding type. Use 'resnet', 'vit' or a custom nn.Module.")
+
+         self.x_projection = nn.Linear(embedding_layer.latent_dim, 512)
+         self.y_projection = nn.Linear(num_classes, 512)
+
+         self.transformer_layer = nn.TransformerEncoderLayer(
+             d_model=1024, nhead=nheads, dim_feedforward=2048, norm_first=True
+         )
+         self.transformer = nn.TransformerEncoder(self.transformer_layer, num_layers=nlayer)
+         self.fc = nn.Linear(1024, num_classes)
+         self._init_weights()
+
+         self.num_classes = num_classes
+
+         self.embedding = embedding_layer
+
+         self.context_images = None
+         self.context_labels = None
+
+     def to(self, device):
+         self.embedding = self.embedding.to(device)
+         self.x_projection = self.x_projection.to(device)
+         self.y_projection = self.y_projection.to(device)
+         self.transformer = self.transformer.to(device)
+         self.fc = self.fc.to(device)
+         return self
+
+     @property
+     def device(self):
+         return self.embedding.device
+
+     def _init_weights(self):
+         # Loop through all parameters in the model
+         for name, param in self.named_parameters():
+             if 'weight' in name:
+                 if param.dim() > 1:  # Apply Xavier only to 2D+ parameters
+                     nn.init.xavier_uniform_(param)
+             elif 'bias' in name:
+                 nn.init.zeros_(param)  # Bias is initialized to zero
+
+     def normalize_samples(self, x, resize=(224, 224)):
+         """
+         Normalize and resize the input images.
+         :param x: Tensor of shape (batch, num_images, 3, 224, 224)
+         :param resize: Tuple for resizing images
+         :return: Normalized and resized images
+         """
+
+         original_shape = x.shape
+         if len(original_shape) == 5:
+             # Reshape to (batch * num_images, 3, 224, 224)
+             x = x.view(-1, 3, 224, 224)
+         elif len(original_shape) == 3:
+             x = x.unsqueeze(0)  # Add batch dimension if missing
+
+         # Rescale images to the specified size
+         if resize is not None:
+             x = F.interpolate(x, size=resize, mode='bilinear', align_corners=False)
+
+         # Normalize images to [0, 1] range
+         if x.max() > 1.0:
+             x = x / 255.0
+         # Standardize with the channel mean/std used during training
+         mean = torch.tensor([0.4914, 0.4822, 0.4465], device=x.device).view(1, 3, 1, 1)
+         std = torch.tensor([0.2023, 0.1994, 0.2010], device=x.device).view(1, 3, 1, 1)
+
+         x = (x - mean) / std
+
+         # Reshape back to (batch, num_images, 3, 224, 224)
+         if len(original_shape) == 5:
+             x = x.view(original_shape[0], original_shape[1], 3, resize[0], resize[1])
+         elif len(original_shape) == 3:
+             x = x.squeeze(0)
+
+         return x
+
+     def set_context_images(self, context_images, context_labels):
+         """
+         Set the context images and labels for the model.
+         :param context_images: Tensor of shape (1, num_images, 3, 224, 224) or a list of PIL images
+         :param context_labels: Tensor of shape (1, num_images) or a list of integer labels
+         """
+         if isinstance(context_images, list) and all(isinstance(img, Image.Image) for img in context_images):
+             # Convert list of PIL images to a (num_images, 224, 224, 3) array ...
+             context_images = np.stack([np.array(img.resize((224, 224))) for img in context_images])
+             context_images = torch.tensor(context_images, dtype=torch.float32)
+             # ... and bring it into channel-first (1, num_images, 3, 224, 224) layout
+             context_images = context_images.permute(0, 3, 1, 2).contiguous().unsqueeze(0)
+         if isinstance(context_labels, list):
+             context_labels = torch.tensor(context_labels, dtype=torch.int64)
+             context_labels = context_labels.unsqueeze(0)  # Shape: (1, num_images)
+
+         if context_images.ndim == 4:
+             context_images = context_images.unsqueeze(0)
+
+         assert context_images.ndim == 5, "context_images must be of shape (1, num_images, 3, 224, 224)"
+         assert context_labels.ndim == 2, "context_labels must be of shape (1, num_images)"
+
+         context_images = self.normalize_samples(context_images, resize=(224, 224))
+
+         self.context_images = context_images
+         self.context_labels = context_labels
+
+     def predict(self, x_pred):
+         """
+         Predict the class for the given image(s).
+         :param x_pred: PIL image, list of PIL images, or tensor of shape (batch, 3, 224, 224)
+         :return: Predicted class index
+         """
+         if self.context_images is None or self.context_labels is None:
+             raise ValueError("Context images and labels must be set before prediction.")
+
+         if isinstance(x_pred, list) and all(isinstance(img, Image.Image) for img in x_pred):
+             # Convert list of PIL images to a channel-first float tensor
+             x_pred = np.stack([np.array(img.resize((224, 224))) for img in x_pred])
+             x_pred = torch.tensor(x_pred, dtype=torch.float32)
+             x_pred = x_pred.permute(0, 3, 1, 2).contiguous()  # Shape: (num_images, 3, 224, 224)
+             x_pred = x_pred / 255.0  # Normalize to [0, 1] range
+         if isinstance(x_pred, Image.Image):
+             # Convert single PIL image to a channel-first float tensor
+             x_pred = np.array(x_pred.resize((224, 224)))
+             x_pred = torch.tensor(x_pred, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)  # Shape: (1, 3, 224, 224)
+             x_pred = x_pred / 255.0  # Normalize to [0, 1] range
+
+         # Expand reference images and labels to match the batch size
+         batch_size = x_pred.size(0)
+         context_images = self.context_images.expand(batch_size, -1, -1, -1, -1)
+         context_labels = self.context_labels.expand(batch_size, -1)
+         # Concatenate context images and labels with prediction images
+         x_train = context_images.view(batch_size, -1, 3, 224, 224)  # Shape: (batch, num_context_images, 3, 224, 224)
+         y_train = context_labels.view(batch_size, -1)  # Shape: (batch, num_context_images)
+
+         x_pred = self.normalize_samples(x_pred, resize=(224, 224))  # Normalize prediction images
+
+         # Move to device
+         x_train = x_train.to(self.embedding.device)
+         y_train = y_train.to(self.embedding.device)
+         x_pred = x_pred.to(self.embedding.device)
+
+         output = self.forward(x_train, y_train, x_pred, embedd=True)
+
+         pred = torch.argmax(output, dim=1)
+
+         return pred.item()
+
+     def forward(self, x_train, y_train, x_pred, embedd=True):
+         if embedd:
+             x_embedded = self.embedding(x_train)  # Shape: (batch, seq, embedding_dim)
+             # (batch, 3, 224, 224) -> (batch, 1, 3, 224, 224) so the embedding wrapper sees a sequence of length 1
+             x_pred = x_pred.unsqueeze(1)
+             x_pred_embedded = self.embedding(x_pred)  # Shape: (batch, 1, embedding_dim)
+         else:
+             x_embedded = x_train
+             x_pred_embedded = x_pred
+
+         x_projected = self.x_projection(x_embedded)  # Shape: (batch, seq, projection_dim)
+
+         # Ensure y_train has the right dimensions
+         y_train = y_train.unsqueeze(-1) if y_train.ndim == 1 else y_train  # Ensure shape (batch, seq)
+
+         # One-hot encode y_train: (batch, seq) -> (batch, seq, num_classes)
+         y_train = F.one_hot(y_train, num_classes=self.num_classes).float()
+
+         # (batch, seq, num_classes) -> (batch * seq, num_classes)
+         y_train = y_train.view(-1, self.num_classes)
+
+         y_projected = self.y_projection(y_train)  # Shape: (batch * seq, projection_dim)
+
+         # Reshape back to (batch, seq, projection_dim)
+         y_projected = y_projected.view(x_projected.size(0), x_projected.size(1), -1)
+
+         # Concatenate x and y projections
+         combined_embedded = torch.cat([x_projected, y_projected], dim=-1)  # Shape: (batch, seq, d_model)
+
+         # Apply the same projection to the prediction
+         x_pred_projected = self.x_projection(x_pred_embedded)  # Shape: (batch, 1, projection_dim)
+
+         y_pred_projected = torch.zeros_like(x_pred_projected, device=self.device) - 1  # placeholder label embedding for the query
+
+         # Concatenate x_pred and y_pred projections
+         pred_combined_embedded = torch.cat([x_pred_projected, y_pred_projected], dim=-1)  # Shape: (batch, 1, d_model)
+
+         # Concatenate train and prediction embeddings
+         full_sequence = torch.cat([combined_embedded, pred_combined_embedded], dim=1)  # Shape: (batch, seq+1, d_model)
+
+         # (batch, seq, dim) -> (seq, batch, dim)
+         full_sequence = full_sequence.permute(1, 0, 2)
+
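+         # The mask lets every position attend to the labeled context pairs but hides the
+         # query (last position) from the context rows, so the context representations are
+         # not influenced by the unlabeled query. In nn.TransformerEncoder, a mask value of
+         # 0.0 means "may attend" and -inf means "blocked".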
+         # Create an attention mask
+         seq_length = full_sequence.size(0)
+         attention_mask = torch.ones(seq_length, seq_length, device=self.device)
+         attention_mask[:-1, -1] = 0
+         attention_mask = attention_mask.masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
+
+         # Pass through transformer encoder
+         transformer_output = self.transformer(full_sequence, mask=attention_mask)
+
+         # Extract the prediction hidden state and compute logits
+         prediction_hidden_state = transformer_output[-1, :, :]  # Shape: (batch_size, hidden_dim)
+         # Calculate final logits
+         logits = self.fc(prediction_hidden_state)  # Shape: (batch_size, num_classes)
+
+         return logits
PictSure/model_ViT.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.grid_size = img_size // patch_size  # e.g. 14 if 224 // 16
+         self.num_patches = self.grid_size ** 2
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         # x: [B, 3, H, W]
+         # project to embeddings with shape [B, D, #patches_row, #patches_col]
+         x = self.proj(x)  # -> [B, embed_dim, grid_size, grid_size]
+         # flatten the spatial dims
+         x = x.flatten(2)  # -> [B, embed_dim, grid_size*grid_size]
+         x = x.transpose(1, 2)  # -> [B, #patches, embed_dim]
+         return x
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x)  # -> [B, N, 3*C]
+         qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+         qkv = qkv.permute(2, 0, 3, 1, 4)  # -> [3, B, heads, N, C//heads]
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         # scaled dot product
+         attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+ class MLP(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = nn.GELU()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+ class Block(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True,
+                  drop=0.0, attn_drop=0.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias,
+             attn_drop=attn_drop, proj_drop=drop
+         )
+         self.norm2 = nn.LayerNorm(dim)
+         self.mlp = MLP(
+             in_features=dim, hidden_features=int(dim*mlp_ratio),
+             out_features=dim, drop=drop
+         )
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+ class VisionTransformer(nn.Module):
+     def __init__(
+         self,
+         img_size=224,
+         patch_size=16,
+         in_chans=3,
+         num_classes=1000,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_rate=0.0,
+         attn_drop_rate=0.0
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.embed_dim = embed_dim
+         self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+         self.num_patches = self.patch_embed.num_patches
+
+         # CLS token
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         # 1D positional embedding
+         self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches+1, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # Transformer blocks
+         self.blocks = nn.ModuleList([
+             Block(embed_dim, num_heads, mlp_ratio,
+                   qkv_bias, drop_rate, attn_drop_rate)
+             for _ in range(depth)
+         ])
+         self.norm = nn.LayerNorm(embed_dim)
+
+         # Classifier head
+         self.head = nn.Linear(embed_dim, num_classes)
+
+         # Weight initialization
+         self._init_weights()
+
+     def _init_weights(self):
+         # simple initialization
+         torch.nn.init.normal_(self.pos_embed, std=0.02)
+         torch.nn.init.normal_(self.cls_token, std=0.02)
+         torch.nn.init.xavier_uniform_(self.head.weight)
+         torch.nn.init.normal_(self.head.bias, std=1e-6)
+
+     def forward(self, x):
+         # x shape: [B, 3, H, W]
+         B = x.shape[0]
+         x = self.patch_embed(x)  # -> [B, N, D]
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # -> [B, 1, D]
+         x = torch.cat((cls_tokens, x), dim=1)  # -> [B, N+1, D]
+
+         x = x + self.pos_embed[:, :(x.size(1)), :]
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x = self.norm(x)
+         # extract CLS token
+         cls_token_final = x[:, 0]
+         # classification
+         logits = self.head(cls_token_final)
+
+         return logits, cls_token_final
PictSure/model_embeddings.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ from .model_ViT import VisionTransformer
+
+ def load_encoder(device="cpu"):
+     base_model = models.resnet18(pretrained=True)
+     encoder = ResNetWrapper(base_model).to(device)
+     return encoder
+
+ class ResNetWrapper(nn.Module):
+     def __init__(self, classifier):
+         super(ResNetWrapper, self).__init__()
+         self.feature_extractor = nn.Sequential(*list(classifier.children())[:-1], torch.nn.Flatten())
+         self.latent_dim = self.feature_extractor(torch.zeros(1, 3, 224, 224)).shape[-1]
+
+     def forward(self, x):
+         num_images = x.size(1)
+         batch_size = x.size(0)
+         x = x.view(-1, 3, 224, 224)
+         x = self.feature_extractor(x)
+         x = x.view(batch_size, num_images, self.latent_dim)
+         return x
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+ class VitNetWrapper(nn.Module):
+     def __init__(self, path, num_classes=1000):
+         super().__init__()
+         self.embedding = VisionTransformer(num_classes=num_classes)
+         if path:
+             self.embedding.load_state_dict(torch.load(path))
+         self.latent_dim = self.embedding.embed_dim
+
+     def forward(self, x):
+         num_images = x.size(1)
+         batch_size = x.size(0)
+         x = x.view(-1, 3, 224, 224)
+         x = self.embedding.forward(x)[1]
+         x = x.view(batch_size, num_images, self.latent_dim)
+         return x
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
pictsure-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,148 @@
+ Metadata-Version: 2.4
+ Name: PictSure
+ Version: 0.1.0
+ Summary: A package for generalized image classification using In-Context-Learning with PyTorch.
+ Author-email: Cornelius Wolff <cornelius.wolff@cwi.nl>, Lukas Schiesser <lukas.schiesser@dfki.de>
+ License: MIT License
+
+ Copyright (c) 2025 Cornelius Wolff; Lukas Schiesser
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.7.0
+ Requires-Dist: torchvision>=0.22.0
+ Requires-Dist: numpy>=1.26.4
+ Requires-Dist: Pillow
+ Requires-Dist: click>=8.1.7
+ Requires-Dist: tqdm>=4.66.4
+ Requires-Dist: requests>=2.32.3
+ Requires-Dist: huggingface-hub>=0.33.1
+ Requires-Dist: safetensors>=0.5.3
+ Dynamic: license-file
+
+ # PictSure: In-Context Learning for Image Classification
+
+ PictSure is a deep learning library designed for **in-context learning** using images and labels. It allows users to provide a set of labeled reference images and then predict labels for new images based on those references. This approach eliminates the need for traditional training, making it highly adaptable for various classification tasks.
+
+ <p align="center">
+   <img src="images/Flow-Chart.png" alt="The classification process" width="90%" />
+ </p>
+
+ ## Features
+ - **In-Context Learning**: Predict labels for new images using a set of reference images without traditional model training.
+ - **Multiple Model Architectures**: Choose between ResNet and ViT-based models for your specific needs.
+ - **Pretrained Models**: Use our pretrained models or train your own.
+ - **Torch Compatibility**: Fully integrated with PyTorch, supporting CPU and GPU.
+ - **Easy-to-use CLI**: Manage models and weights through a simple command-line interface.
+
+ ## Installation
+ 1. Clone this repository
+ ```bash
+ git clone https://git.ni.dfki.de/pictsure/pictsure-library
+ ```
+ 2. Navigate into the folder
+ ```bash
+ cd pictsure-library
+ ```
+ 3. Install the pip package
+ ```bash
+ pip install .
+ ```
+
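+ Since this version is published as a wheel, installing the released package directly from PyPI should also work:
+ ```bash
+ pip install pictsure
+ ```
+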
+ ## Quick Start
+ ```python
+ from PictSure import PictSure
+ import torch
+
+ # Initialize the model (using ViT as an example)
+ model = PictSure(
+     embedding='vit',  # or 'resnet'
+     pretrained=True,  # use pretrained weights
+     device='cuda'     # or 'cpu'
+ )
+
+ # you can also pull our pre-trained models from Huggingface
+ model = PictSure.from_pretrained("pictsure/pictsure-vit")
+
+ # Set your reference images and labels
+ model.set_context_images(reference_images, reference_labels)
+
+ # Make predictions on new images
+ predictions = model.predict(new_images)
+ ```
+
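+ The `reference_images` and `reference_labels` passed to `set_context_images` are simply a list of PIL images and a matching list of integer class labels. A minimal sketch of building them from a folder-per-class dataset (the `my_dataset/` layout and the 5-shot count are only illustrative):
+ ```python
+ import os
+ from PIL import Image
+
+ reference_images, reference_labels = [], []
+ for label, folder in enumerate(sorted(os.listdir("my_dataset"))):  # one sub-folder per class
+     for name in sorted(os.listdir(os.path.join("my_dataset", folder)))[:5]:  # a few shots per class
+         img = Image.open(os.path.join("my_dataset", folder, name)).convert("RGB")
+         reference_images.append(img)
+         reference_labels.append(label)
+ ```
+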
+ ## Command Line Interface
+ PictSure comes with a command-line interface to manage models and weights:
+
+ ### List Available Models
+ ```bash
+ pictsure list-models
+ ```
+ This command shows all available models, their status (downloaded/not downloaded), and detailed information about each model.
+
+ ### Remove Model Weights
+ ```bash
+ pictsure remove <model_name> [--force]
+ ```
+ Remove the weights of a specific model. Available models are:
+ - `ViTPreAll`: ViT-based model
+ - `ResPreAll`: ResNet-based model
+
+ Use the `--force` or `-f` flag to skip the confirmation prompt.
+
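+ For example, to delete the downloaded ViT weights without a confirmation prompt:
+ ```bash
+ pictsure remove ViTPreAll --force
+ ```
+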
+ ## Examples
+ For a complete working example, check out the Jupyter notebook in the Examples directory:
+ ```bash
+ Examples/example.ipynb
+ ```
+ This notebook demonstrates:
+ - Model initialization
+ - Loading and preprocessing images
+ - Setting up reference images
+ - Making predictions
+ - Visualizing results
+
+ ## Citation
+
+ If you use this work, please cite it using the following BibTeX entry:
+
+ ```bibtex
+ @article{schiesser2025pictsure,
+   title={PictSure: Pretraining Embeddings Matters for In-Context Learning Image Classifiers},
+   author={Schiesser, Lukas and Wolff, Cornelius and Haas, Sophie and Pukrop, Simon},
+   journal={arXiv preprint arXiv:2506.14842},
+   year={2025}
+ }
+ ```
+
+ ## License
+ This project is open-source under the MIT License.
+
+ ## Contributing
+ Contributions and suggestions are welcome! Open an issue or submit a pull request.
+
+ ## Contact
+ For questions or support, open an issue on GitHub.
+
pictsure-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ Examples/example.py,sha256=kJl9EOYegRYnv_Q3enzxRw3_-1JahuXrBJlMKQ1lI9c,2480
+ PictSure/__init__.py,sha256=JD4sCnhwmPdDvCv4D2Tm5ZlrTVY7Yi1jYNI91hYKR2Q,154
+ PictSure/cli.py,sha256=mchQYMF-PXr9QhYGP05_EUkZRy99KB8NYx-Cfz01IAA,3317
+ PictSure/model_PictSure.py,sha256=hAv2Wc7N1sX-OgPB1aTo9lJxc9uEu-unvqTH8tzcgXU,10727
+ PictSure/model_ViT.py,sha256=7XXgyyZrT5v_1ReTwaCvR4EJ8VjXLBrmbPElu4reDMc,5372
+ PictSure/model_embeddings.py,sha256=o8_T-JE7dOUB7QgTQHMIzjCbNSGQux3css84lZTNxTw,1531
+ pictsure-0.1.0.dist-info/licenses/LICENSE,sha256=EWEw5rrEDvPxG3Wz_TSZJoGJ3J9k1Rv6yMluaBojABc,1089
+ pictsure-0.1.0.dist-info/METADATA,sha256=RVV2j6gRZI7dfwyjYSPrvyFTf-Yoa-kFM7gOZ7TwcCg,5270
+ pictsure-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pictsure-0.1.0.dist-info/entry_points.txt,sha256=TROsY1gBQxYsQfyNHKCSwDF6sxsFJYWdRguAtR1O1ec,46
+ pictsure-0.1.0.dist-info/top_level.txt,sha256=4c6FfUQfr4v2hzAizS1iifVQaGVSLWweO2DICgcIbe4,18
+ pictsure-0.1.0.dist-info/RECORD,,
pictsure-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
pictsure-0.1.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ pictsure = PictSure.cli:cli
pictsure-0.1.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Cornelius Wolff; Lukas Schiesser
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
pictsure-0.1.0.dist-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ Examples
+ PictSure