lora-nf 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lora_nf-0.1.1/PKG-INFO ADDED
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: lora-nf
3
+ Version: 0.1.1
4
+ Summary: Low-Rank Adaptation of Neural Fields
5
+ Home-page: https://github.com/dinhanhtruong/LoRA-NF
6
+ Author: Anh Truong
7
+ License: MIT License
8
+ Classifier: Programming Language :: Python :: 3.12
9
+ Requires-Dist: torch
10
+ Requires-Dist: numpy
11
+ Requires-Dist: imageio
12
+ Requires-Dist: trimesh
13
+ Requires-Dist: pysdf
14
+ Requires-Dist: PyMCubes
15
+ Requires-Dist: packaging
16
+ Dynamic: author
17
+ Dynamic: classifier
18
+ Dynamic: home-page
19
+ Dynamic: license
20
+ Dynamic: requires-dist
21
+ Dynamic: summary
@@ -0,0 +1,2 @@
1
+ # LoRA-NF
2
+ TODO
@@ -0,0 +1 @@
1
+ __version__ = "0.1.1"
@@ -0,0 +1,151 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Redistribution and use in source and binary forms, with or without modification, are permitted
6
+ # provided that the following conditions are met:
7
+ # * Redistributions of source code must retain the above copyright notice, this list of
8
+ # conditions and the following disclaimer.
9
+ # * Redistributions in binary form must reproduce the above copyright notice, this list of
10
+ # conditions and the following disclaimer in the documentation and/or other materials
11
+ # provided with the distribution.
12
+ # * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
13
+ # to endorse or promote products derived from this software without specific prior written
14
+ # permission.
15
+ #
16
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
17
+ # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
18
+ # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
19
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
20
+ # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21
+ # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
22
+ # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
25
+ import imageio
26
+ import numpy as np
27
+ import os
28
+ import struct
29
+
30
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
31
+
32
def mse2psnr(x):
    """Convert a mean-squared error (data range [0, 1]) to PSNR in dB."""
    # np.log10 replaces the hand-rolled np.log(x)/np.log(10.)
    return -10. * np.log10(x)
34
+
35
def write_image_imageio(img_file, img, quality):
    """Quantize a float image in [0, 1] to uint8 and write it with imageio.

    Args:
        img_file: destination path; the extension selects the output format.
        img: float array, H x W x C; values outside [0, 1] are clipped.
        quality: JPEG quality setting (used only for .jpg/.jpeg output).
    """
    # +0.5 before truncation rounds to nearest instead of flooring
    img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    kwargs = {}
    if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
        if img.ndim >= 3 and img.shape[2] > 3:
            # JPEG has no alpha channel; drop extra channels
            img = img[:,:,:3]
        if img.ndim >= 3 and img.shape[2] == 1: # grayscale
            img = np.squeeze(img)
        kwargs["quality"] = quality
        # disable chroma subsampling for maximum color fidelity
        kwargs["subsampling"] = 0
    imageio.imwrite(img_file, img, **kwargs)
46
+
47
def read_image_imageio(img_file):
    """Load an image via imageio as float32 in [0, 1], always H x W x C.

    Grayscale images get a trailing singleton channel axis appended.
    Note: assumes 8-bit source data (divides by 255) — 16-bit inputs
    would come out with values > 1; confirm against the datasets used.
    """
    img = imageio.imread(img_file)
    img = np.asarray(img).astype(np.float32)
    if len(img.shape) == 2:
        # promote grayscale H x W to H x W x 1
        img = img[:,:,np.newaxis]
    return img / 255.0
53
+
54
def srgb_to_linear(img):
    """Inverse sRGB gamma: map sRGB-encoded values to linear intensity."""
    threshold = 0.04045
    low = img / 12.92
    high = np.power((img + 0.055) / 1.055, 2.4)
    return np.where(img > threshold, high, low)
57
+
58
def linear_to_srgb(img):
    """Forward sRGB gamma: map linear intensity to sRGB-encoded values."""
    threshold = 0.0031308
    low = 12.92 * img
    high = 1.055 * (img ** (1.0 / 2.4)) - 0.055
    return np.where(img > threshold, high, low)
61
+
62
def read_image(file):
    """Load an image as a linear-color float32 array.

    ".bin" files are raw dumps: two int32 (height, width) followed by
    h*w*4 float16 values, returned as-is (assumed already linear RGBA).
    Everything else goes through imageio and is converted from sRGB to
    linear; 4-channel images additionally get alpha premultiplied.
    """
    if os.path.splitext(file)[1] == ".bin":
        with open(file, "rb") as f:
            raw = f.read()  # renamed from `bytes`, which shadowed the builtin
        h, w = struct.unpack("ii", raw[:8])
        img = np.frombuffer(raw, dtype=np.float16, count=h*w*4, offset=8).astype(np.float32).reshape([h, w, 4])
    else:
        img = read_image_imageio(file)
        if img.shape[2] == 4:
            img[...,0:3] = srgb_to_linear(img[...,0:3])
            # Premultiply alpha
            img[...,0:3] *= img[...,3:4]
        else:
            img = srgb_to_linear(img)
    return img
77
+
78
def write_image(file, img, quality=100):
    """Save a linear-color image, dispatching on extension.

    ".bin" gets a raw dump (int32 height, int32 width, then float16 pixel
    data padded to 4 channels with ones); anything else is converted to
    sRGB (un-premultiplying alpha for RGBA) and written via imageio.
    """
    if os.path.splitext(file)[1] == ".bin":
        if img.shape[2] < 4:
            # pad missing channels with ones so the dump is always RGBA
            pad = np.ones([img.shape[0], img.shape[1], 4 - img.shape[2]])
            img = np.dstack((img, pad))
        with open(file, "wb") as f:
            f.write(struct.pack("ii", img.shape[0], img.shape[1]))
            f.write(img.astype(np.float16).tobytes())
        return
    if img.shape[2] == 4:
        img = np.copy(img)
        # Unmultiply alpha, guarding alpha == 0 pixels against division by zero
        img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0)
        img[...,0:3] = linear_to_srgb(img[...,0:3])
    else:
        img = linear_to_srgb(img)
    write_image_imageio(file, img, quality)
94
+
95
def trim(error, skip=0.000001):
    """Trimmed mean: drop the `skip` fraction of values from each sorted tail."""
    flat = np.sort(error.flatten())
    n_drop = int(skip * flat.size)
    return flat[n_drop:flat.size - n_drop].mean()
100
+
101
def luminance(a):
    """Rec. 709 luma of a gamma-compressed (exponent ~1/2.2) copy of `a` (H x W x 3)."""
    gamma = np.maximum(0, a) ** 0.4545454545
    r, g, b = gamma[:,:,0], gamma[:,:,1], gamma[:,:,2]
    return 0.2126 * r + 0.7152 * g + 0.0722 * b
104
+
105
def L1(img, ref):
    """Element-wise absolute error."""
    diff = img - ref
    return np.absolute(diff)
107
+
108
def APE(img, ref):
    """Absolute error relative to the reference, with a 1e-2 stabilizer."""
    return np.abs(img - ref) / (1e-2 + ref)
110
+
111
def SAPE(img, ref):
    """Symmetric absolute percentage error with a 1e-2 stabilizer."""
    midpoint = (ref + img) / 2.
    return np.abs(img - ref) / (1e-2 + midpoint)
113
+
114
def L2(img, ref):
    """Element-wise squared error."""
    diff = img - ref
    return diff * diff
116
+
117
def RSE(img, ref):
    """Squared error relative to the squared reference (1e-2 stabilizer)."""
    return (img - ref)**2 / (1e-2 + ref**2)
119
+
120
def rgb_mean(img):
    """Average over the channel axis (axis 2)."""
    return img.mean(axis=2)
122
+
123
def compute_error_img(metric, img, ref):
    """Per-pixel error map between `img` and `ref` under the named metric.

    Non-finite entries of `img` are treated as 0 and negative values are
    clamped to 0 before comparison. Unlike the original, the caller's array
    is left untouched (it used to be mutated in place).

    Args:
        metric: one of "MAE", "MAPE", "SMAPE", "MSE", "MScE", "MRSE",
            "MtRSE", "MRScE".
        img, ref: equal-shaped arrays.

    Raises:
        ValueError: if `metric` is not a supported name.
    """
    # sanitize a copy instead of writing into the caller's array
    img = np.where(np.isfinite(img), img, 0)
    img = np.maximum(img, 0.)
    if metric == "MAE":
        return L1(img, ref)
    elif metric == "MAPE":
        return APE(img, ref)
    elif metric == "SMAPE":
        return SAPE(img, ref)
    elif metric == "MSE":
        return L2(img, ref)
    elif metric == "MScE":
        return L2(np.clip(img, 0.0, 1.0), np.clip(ref, 0.0, 1.0))
    elif metric == "MRSE":
        return RSE(img, ref)
    elif metric == "MtRSE":
        return trim(RSE(img, ref))
    elif metric == "MRScE":
        return RSE(np.clip(img, 0, 100), np.clip(ref, 0, 100))

    raise ValueError(f"Unknown metric: {metric}.")
144
+
145
def compute_error(metric, img, ref):
    """Scalar mean error between `img` and `ref` under the named metric.

    Non-finite entries of the per-pixel error map are zeroed before
    averaging, and channel maps are first reduced to per-pixel means.
    """
    metric_map = compute_error_img(metric, img, ref)
    # zero out NaN/inf produced by the metric itself (e.g. divisions)
    metric_map[np.logical_not(np.isfinite(metric_map))] = 0
    if len(metric_map.shape) == 3:
        metric_map = np.mean(metric_map, axis=2)
    mean = np.mean(metric_map)
    return mean
@@ -0,0 +1,169 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from .common import read_image, write_image
5
+ import trimesh
6
+ from .util import save_mesh
7
+ import pysdf
8
+
9
class DataSampler:
    """Abstract interface for sampling (input, target) pairs from a field."""

    def sample_batch(self, n_samples: int, device: torch.device):
        """Draw input positions and the corresponding field values.

        Args:
            n_samples: number of samples to draw.
            device: torch device (e.g. "cpu", "cuda:0") for the returned tensors.

        Returns:
            inputs: [n_samples, data_input_dim] tensor of sample positions
                (e.g. [n_samples, 2] for image xy positions).
            targets: [n_samples, data_output_dim] tensor of field values at
                the inputs (e.g. [n_samples, 3] for image RGB values).
        """
        raise NotImplementedError

    def save_model_output(self, model: torch.nn.Module, save_path: str):
        """Write a reconstruction of the neural field `model` to `save_path`.

        `save_path` carries no extension; the directory is assumed to exist.
        Returns nothing.
        """
        raise NotImplementedError
33
+
34
class Image(DataSampler, torch.nn.Module):
    """DataSampler over a single image: inputs are xy in [0, 1]^2, targets RGB."""

    def __init__(self, filename: str, device: torch.device):
        """
        Args:
            filename: path to the image file.
            device: torch.device the image data and samples live on.
        """
        super().__init__()
        self.data = read_image(filename)
        # remove alpha channel
        if self.data.ndim > 2 and self.data.shape[2] == 4:
            self.data = self.data[:,:,:3] # keep RGB
        self.orig_data_npy = self.data  # numpy copy kept before the torch conversion
        self.shape = self.data.shape
        self.data = torch.from_numpy(self.data).float().to(device)
        self.device = device

        # precompute the pixel-center coordinate grid for model output reconstruction
        resolution = self.data.shape[0:2]
        half_dx = 0.5 / resolution[0] # half pixel size
        half_dy = 0.5 / resolution[1]
        xs = torch.linspace(half_dx, 1-half_dx, resolution[0])
        ys = torch.linspace(half_dy, 1-half_dy, resolution[1])
        xv, yv = torch.meshgrid([xs, ys], indexing="ij")
        self.img_shape = resolution + torch.Size([self.data.shape[2]])
        # NOTE(review): stacked as (yv, xv), i.e. column 0 holds the second
        # grid axis — presumably to match forward()'s (x, y) ordering; confirm
        # against save_model_output's reshape before changing.
        self.xy = torch.stack((yv.flatten(), xv.flatten())).t().to(device)

    def forward(self, xs, interpolate=True):
        """Look up image colors at normalized coordinates `xs` ([N, 2] in [0, 1]^2).

        With `interpolate` the lookup is bilinear; otherwise nearest-texel.
        """
        with torch.no_grad():
            # Bilinearly filtered lookup from the image. Not super fast,
            # but less than ~20% of the overall runtime of this example.
            shape = self.shape

            # scale to pixel units: column 0 by width, column 1 by height
            xs = xs * torch.tensor([shape[1], shape[0]], device=xs.device).float()
            indices = xs.long()

            x0 = indices[:, 0].clamp(min=0, max=shape[1]-1)
            y0 = indices[:, 1].clamp(min=0, max=shape[0]-1)
            if interpolate:
                # fractional parts are the bilinear blend weights
                lerp_weights = xs - indices.float()
                x1 = (x0 + 1).clamp(max=shape[1]-1)
                y1 = (y0 + 1).clamp(max=shape[0]-1)

                return (
                    self.data[y0, x0] * (1.0 - lerp_weights[:,0:1]) * (1.0 - lerp_weights[:,1:2]) +
                    self.data[y0, x1] * lerp_weights[:,0:1] * (1.0 - lerp_weights[:,1:2]) +
                    self.data[y1, x0] * (1.0 - lerp_weights[:,0:1]) * lerp_weights[:,1:2] +
                    self.data[y1, x1] * lerp_weights[:,0:1] * lerp_weights[:,1:2]
                )

            return (self.data[y0, x0]) # no interpolation

    def sample_batch(self, n_samples, device):
        """Uniformly sample xy positions and return their interpolated colors."""
        assert device == self.device
        input_xy = torch.rand([n_samples, 2], dtype=torch.float, device=device)
        image_rgb = self.forward(input_xy)
        return input_xy, image_rgb

    def save_model_output(self, model, save_path):
        """Render `model` over the full pixel grid and write `{save_path}.png`."""
        write_image(f"{save_path}.png", model(self.xy).reshape(self.img_shape).clamp(0.0, 1.0).detach().cpu().numpy())
95
+
96
+
97
class SDF(DataSampler):
    """DataSampler over the signed distance field of a mesh.

    The mesh is rescaled/translated to (roughly) fit [-1, 1]^3 and queried
    through pysdf.
    """

    def __init__(self, path: str, device: torch.device, num_samples=2**18, clip_sdf=None, transformation_save_dir=""):
        """
        Args:
            path: mesh file loadable by trimesh.
            device: torch device that sample_batch must be called with.
            num_samples: default sample count; must be divisible by 8.
            clip_sdf: if given, SDF targets are clipped to [-clip_sdf, clip_sdf].
            transformation_save_dir: directory caching the normalization transform.
        """
        super().__init__()
        self.path = path
        # BUG FIX: was `self.device = self.device`, which raised AttributeError
        # because the attribute did not exist yet.
        self.device = device

        # load obj
        self.mesh = trimesh.load(path, force='mesh')

        # normalize to [-1, 1] (different from instant-sdf where is [0, 1]) via scaling and translation
        vs = self.mesh.vertices
        vmin = vs.min(0)
        vmax = vs.max(0)
        v_center = (vmin + vmax) / 2
        # 0.95 leaves a margin inside the unit cube
        v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95

        # TODO: save normalizing transformation of base model and apply same transformation to deformed model
        transformation_path = f"{transformation_save_dir}/data_normalization_transformation.npz"
        if not os.path.exists(transformation_path):
            print("saving base mesh normalization transformation")
            np.savez(transformation_path, v_scale=v_scale, v_center=v_center)
        else:
            print("######################")
            print("####### TEMP DISABLED CONSISTENT SDF NORMALIZATION")
            print("######################")
            print("######################")
            print("######################")
            print("######################")
            # print("using existing normalization transformation")
            # loaded_transformation = np.load(transformation_path)
            # v_scale = loaded_transformation["v_scale"]
            # v_center = loaded_transformation["v_center"]

        print("scale: ", v_scale)
        print("center: ", v_center)
        # apply transformation to verts
        vs = (vs - v_center[None, :]) * v_scale
        self.mesh.vertices = vs

        print(f"[INFO] mesh verts & faces: {self.mesh.vertices.shape} & {self.mesh.faces.shape}")

        if not self.mesh.is_watertight:
            print(f"[WARN] mesh is not watertight! SDF maybe incorrect.")
            #trimesh.Scene([self.mesh]).show()

        self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces)

        self.num_samples = num_samples
        assert self.num_samples % 8 == 0, "num_samples must be divisible by 8."
        self.clip_sdf = clip_sdf

    def sample_batch(self, n_samples, device):
        """Sample points (surface / near-surface / uniform mix) and their SDF values.

        The first n/2 targets stay 0 (exact surface samples); SDF is only
        evaluated for the perturbed and uniform points.
        """
        assert device == self.device
        # online sampling
        sdfs = np.zeros((n_samples, 1))
        # surface query points (7/8 points for surface and near-surface)
        points_surface = self.mesh.sample(n_samples * 7 // 8)

        # near-surface points: jitter the last 3/8 of the surface samples
        points_surface[n_samples // 2:] += 0.01 * np.random.randn(n_samples * 3 // 8, 3)
        # random uniform points (1/8 of points)
        points_uniform = np.random.rand(n_samples // 8, 3) * 2 - 1
        points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32)
        # NOTE(review): pysdf's sign is negated here — presumably pysdf returns
        # positive inside; confirm against the loss/mesh-extraction convention.
        sdfs[n_samples // 2:] = -self.sdf_fn(points[n_samples // 2:])[:,None].astype(np.float32)

        # clip sdf
        if self.clip_sdf is not None:
            sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf)

        # NOTE(review): sdfs is float64 while points is float32; callers may
        # want a matching dtype — left unchanged to preserve behavior.
        return torch.tensor(points, device=device), torch.tensor(sdfs, device=device)

    def save_model_output(self, model: torch.nn.Module, save_path: str):
        """Extract a mesh from the neural SDF via marching cubes and save it."""
        save_mesh(save_path, model, self.device, resolution=256)
@@ -0,0 +1,122 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class CustomFrequencyEncoding(nn.Module):
    """NeRF-style sinusoidal positional encoding (sin/cos at octave frequencies)."""

    def __init__(self):
        super().__init__()

    def get_encoding_output_dim(self, input_dim):
        """Output width produced for an input of width `input_dim`."""
        probe = torch.zeros((1, input_dim))
        return self.forward(probe).shape[-1]

    def forward( # from https://github.com/krrish94/nerf-pytorch/blob/master/nerf/nerf_helpers.py
        self, tensor, num_encoding_functions=10, include_input=True, log_sampling=True
    ) -> torch.Tensor:
        r"""Positionally encode `tensor` along its last dimension.

        Args:
            tensor (torch.Tensor): input to encode.
            num_encoding_functions (int): number of frequency bands.
            include_input (bool): prepend the raw input to the encoding.
            log_sampling (bool): space frequencies as powers of two
                (2^0 .. 2^(n-1)) rather than linearly over the same range.

        Returns:
            (torch.Tensor): the encoding; equals `tensor` itself when only
            the raw input term is present.
        """
        if log_sampling:
            bands = 2.0 ** torch.linspace(
                0.0,
                num_encoding_functions - 1,
                num_encoding_functions,
                dtype=tensor.dtype,
                device=tensor.device,
            )
        else:
            bands = torch.linspace(
                2.0 ** 0.0,
                2.0 ** (num_encoding_functions - 1),
                num_encoding_functions,
                dtype=tensor.dtype,
                device=tensor.device,
            )

        parts = [tensor] if include_input else []
        for band in bands:
            parts.append(torch.sin(tensor * band))
            parts.append(torch.cos(tensor * band))

        # single term means no positional encoding was actually added
        return parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)
54
+
55
class LoRA_MLP(nn.Module):
    """Wraps a base MLP, replacing each nn.Linear with a LoRA-augmented layer.

    Non-linear layers (activations, positional encodings) are reused as-is,
    so the wrapped network mirrors the base network's structure.
    """

    def __init__(self, base_mlp, rank):
        """
        Args:
            base_mlp: iterable of layers (e.g. nn.Sequential) to adapt.
            rank: maximum rank of each low-rank adaptor.
        """
        super().__init__()
        adapted = []
        for layer in base_mlp:
            if isinstance(layer, nn.Linear):
                # low-rank adaptor wrapping the frozen base layer
                adapted.append(LoRALinear(layer, r=rank))
            else:
                # keep parent layer (activation or positional encoding)
                adapted.append(layer)
        self.sequential = nn.Sequential(*adapted)

    def as_sequential(self):
        """The adapted network as a plain nn.Sequential."""
        return self.sequential

    def get_lora_weights(self):
        """List of LoRA weight tensors, one per LoRALinear, in network order.

        Each entry is (A @ B).T, i.e. shaped like nn.Linear.weight
        ([out_features, in_features]).
        """
        return [
            (layer.A @ layer.B).T
            for layer in self.sequential
            if isinstance(layer, LoRALinear)
        ]

    def forward(self, x):
        return self.sequential(x)
90
+
91
+
92
class LoRALinear(nn.Module):
    """A frozen nn.Linear plus a trainable low-rank residual (LoRA).

    Computes base(x) + (x @ A @ B) * (alpha / r), where A is
    [in_features, rank] and B is [rank, out_features]. B is initialized to
    zero, so the layer initially matches the base layer exactly.
    """

    def __init__(self, base_linear, r=16, alpha=1, A=None):
        """
        Args:
            base_linear: nn.Linear to adapt; its weight (and bias) are frozen.
            r: requested LoRA rank (>= 1); the effective rank is clipped to
                min(r, in_features, out_features).
            alpha: LoRA scale numerator; the residual is scaled by alpha / r.
            A: optional pre-built A matrix of shape [in_features, rank].
                NOTE(review): a plain tensor passed here is not registered as
                a parameter/buffer, so it is neither trained nor saved — pass
                an nn.Parameter if that is intended.
        """
        super().__init__()
        # was `r > 1`, which wrongly rejected valid rank-1 adaptors
        assert isinstance(base_linear, nn.Linear) and r >= 1
        self.base_linear = base_linear
        # freeze base; make sure only lora weights are trainable
        self.base_linear.weight.requires_grad_(False)
        if self.base_linear.bias is not None:
            self.base_linear.bias.requires_grad_(False)

        rank = min(r, base_linear.in_features, base_linear.out_features)
        if A is not None:
            # use specified A matrix
            assert A.shape == torch.Size((base_linear.in_features, rank))
            self.A = A
        else:
            self.A = nn.Parameter(torch.empty((base_linear.in_features, rank)))
            # following pg. 4 of https://arxiv.org/pdf/2406.08447
            nn.init.normal_(self.A, std=1/math.sqrt(base_linear.in_features))
        self.B = nn.Parameter(torch.zeros(rank, base_linear.out_features))
        # lora scale factor; uses the requested r (standard LoRA convention),
        # not the clipped effective rank
        self.scaling = alpha/r

    def forward(self, x):
        return self.base_linear(x) + torch.linalg.multi_dot((x, self.A, self.B))*self.scaling
119
+
120
+
121
def extract_linear_layers(module):
    """All nn.Linear submodules of `module`, in module-traversal order."""
    linears = []
    for sub in module.modules():
        if isinstance(sub, nn.Linear):
            linears.append(sub)
    return linears
@@ -0,0 +1,198 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import time
5
+ import torch
6
+ import copy
7
+ import torch.nn as nn
8
+ from .torch_modules import LoRA_MLP, extract_linear_layers
9
+ from .util import get_device
10
+ from .data_samplers import DataSampler
11
+ from typing import Callable
12
+
13
+ torch.manual_seed(0)
14
+ np.random.seed(0)
15
+ random.seed(0)
16
+
17
+
18
+ def train_lora_regression(
19
+ base_nf: nn.Sequential,
20
+ target_sampler: DataSampler,
21
+ loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
22
+ lora_rank: int,
23
+ learning_rate=5e-3,
24
+ batch_size=2**18,
25
+ max_n_steps=30000,
26
+ lr_scheduler_warmup_steps=7000,
27
+ log_interval=100,
28
+ convergence_patience=15,
29
+ save_dir="",
30
+ ):
31
+ """
32
+ Trains LoRAs for every linear layer of the given base neural field. Returns a new neural field with a LoRA applied and (optionally) the weights of each LoRA
33
+
34
+ Args:
35
+ - base_nf: MLP implemented as a nn.Sequential (containing nn.Linear layers, activation functions, optional input positional encoding)
36
+ - target_sampler: a DataSampler that must be compatible with base_nf (i.e. input/target dimensions must match)
37
+ - loss_fn: callable function (output,target) |-> loss scalar
38
+ - lora_rank: Desired maximum rank of each LoRA
39
+ - learning_rate: for ADAM optimizer
40
+ - batch_size: num samples per step
41
+ - max_n_steps: max number of training steps
42
+ - lr_scheduler_warmup_steps: number of warmup steps for the learning rate scheduler. No warmup if 0.
43
+ - convergence_patience: number of log_intervals of no improvement after which training is terminated early
44
+ - save_dir: if provided, then weights and output reconstructions will be saved there
45
+
46
+ Returns:
47
+ - lora_weights: LoRA weight tensors, one per linear layer of the base model. Returns the iterate with the lowest loss.
48
+ - lora_nf: nn.Sequential neural field with LoRA applied to every linear layer. Returns the iterate with the lowest loss.
49
+ """
50
+ # make sure that base_nf and target_sampler are compatible
51
+ device = get_device(base_nf)
52
+ linear_layers = extract_linear_layers(base_nf)
53
+ base_output_dim = linear_layers[-1].out_features
54
+ _, dummy_targets = target_sampler.sample_batch(batch_size, device)
55
+ assert dummy_targets.shape[-1] == base_output_dim, "base_nf and target_sampler are incompatible: output size mismatch"
56
+ assert isinstance(base_nf, nn.Sequential)
57
+
58
+ # set up lora neural field
59
+ lora_nf = LoRA_MLP(base_nf, lora_rank)
60
+ lora_nf.to(device)
61
+ print(f"training LoRA on {device}")
62
+ trainable_params = sum(p.numel() for p in lora_nf.parameters() if p.requires_grad)
63
+ print(f"# trainable LoRA parameters: {trainable_params}")
64
+
65
+ optimizer = torch.optim.Adam(lora_nf.parameters(), lr=learning_rate) # only fine-tune mlp
66
+ # use learning rate scheduler if specified
67
+ scheduler = None
68
+ if lr_scheduler_warmup_steps > 0:
69
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=lr_scheduler_warmup_steps)
70
+ # for logging
71
+ best_loss = float('inf')
72
+ best_lora_nf = copy.deepcopy(lora_nf)
73
+ prev_time = time.perf_counter()
74
+ periods_no_improve = 0 # for early stopping
75
+ if save_dir:
76
+ os.makedirs(save_dir, exist_ok=True)
77
+
78
+ for i in range(max_n_steps):
79
+ batch_inputs, batch_targets = target_sampler.sample_batch(batch_size, device) # [B,in], [B,out]
80
+ output = lora_nf(batch_inputs.to(device))
81
+ loss = loss_fn(output, batch_targets.to(device))
82
+
83
+ optimizer.zero_grad()
84
+ loss.backward()
85
+ optimizer.step()
86
+ if scheduler:
87
+ scheduler.step()
88
+
89
+ if i % log_interval == 0: # reconstruct output
90
+ curr_loss = loss.item()
91
+ elapsed_time = time.perf_counter() - prev_time
92
+ print(f"Step#{i}: loss={curr_loss:.7f} time={int(elapsed_time)}[s]")
93
+ if save_dir:
94
+ with torch.no_grad():
95
+ target_sampler.save_model_output(lora_nf, save_path=f"{save_dir}/lora_nf_step_{i:05d}")
96
+
97
+ if curr_loss < best_loss:
98
+ print(f"\tdecreased {best_loss:.6f}-->{curr_loss:.6f}")
99
+ best_loss = curr_loss
100
+ periods_no_improve = 0 # reset
101
+ best_lora_nf = copy.deepcopy(lora_nf)
102
+ else:
103
+ # early stopping if no improvement for several epochs
104
+ periods_no_improve += 1
105
+ if periods_no_improve >= convergence_patience:
106
+ print(f"Early stopping at step {i} with loss {curr_loss}")
107
+ break
108
+ prev_time = time.perf_counter()
109
+ if i > 0 and log_interval < 1000:
110
+ log_interval *= 10
111
+ if save_dir:
112
+ torch.save({
113
+ 'step': i,
114
+ 'model_state_dict': best_lora_nf.state_dict(),
115
+ 'loss': curr_loss,
116
+ }, f"{save_dir}/lora_nf_best.pt")
117
+ target_sampler.save_model_output(best_lora_nf, save_path=f"{save_dir}/lora_nf_best")
118
+
119
+ return best_lora_nf.get_lora_weights(), best_lora_nf.as_sequential()
120
+
121
def train_base_model(
    base_nf: nn.Sequential,
    data_sampler: DataSampler,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    learning_rate=1e-4,
    batch_size=2**18,
    max_n_steps=100000,
    lr_scheduler_warmup_steps=0,
    log_interval=100,
    convergence_patience=15,
    save_dir="",
):
    """
    Trains the given base neural field to regress samples from data_sampler.

    Args:
        base_nf: MLP implemented as an nn.Sequential.
        data_sampler: DataSampler providing (inputs, targets) batches.
        loss_fn: callable (output, target) -> scalar loss tensor.
        learning_rate: Adam learning rate.
        batch_size: samples per step.
        max_n_steps: maximum number of training steps.
        lr_scheduler_warmup_steps: warmup steps for the LR scheduler; 0 disables it.
        log_interval: steps between progress logs; grows x10 after the first
            log (capped below 1000).
        convergence_patience: number of log intervals without improvement
            after which training stops early.
        save_dir: if non-empty, the best checkpoint and reconstructions are saved there.

    Returns:
        A deep copy of base_nf at the best (lowest logged loss) iteration.
    """
    device = get_device(base_nf)
    trainable_params = sum(p.numel() for p in base_nf.parameters() if p.requires_grad)
    print(f"training base model on {device}")
    print(f"# trainable base model parameters: {trainable_params}")

    optimizer = torch.optim.Adam(base_nf.parameters(), lr=learning_rate) # only fine-tune mlp
    # use learning rate scheduler if specified
    scheduler = None
    if lr_scheduler_warmup_steps > 0:
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=lr_scheduler_warmup_steps)
    best_loss = float('inf')
    best_model = copy.deepcopy(base_nf)
    prev_time = time.perf_counter()
    periods_no_improve = 0 # for early stopping

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for i in range(max_n_steps):
        batch_inputs, batch_targets = data_sampler.sample_batch(batch_size, device) # [B,in], [B,out]
        output = base_nf(batch_inputs.to(device))
        loss = loss_fn(output, batch_targets.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        if i % log_interval == 0: # reconstruct output
            curr_loss = loss.item()
            elapsed_time = time.perf_counter() - prev_time
            print(f"Step#{i}: loss={curr_loss:.7f} time={int(elapsed_time)}[s]")
            if save_dir:
                with torch.no_grad():
                    data_sampler.save_model_output(base_nf, save_path=f"{save_dir}/base_nf_step_{i:05d}")

            # best-so-far is judged only on logged steps' single-batch loss
            if curr_loss < best_loss:
                print(f"\tdecreased {best_loss:.6f}-->{curr_loss:.6f}")
                best_loss = curr_loss
                periods_no_improve = 0 # reset
                best_model = copy.deepcopy(base_nf)
            else:
                # early stopping if no improvement for several epochs
                periods_no_improve += 1
                if periods_no_improve >= convergence_patience:
                    print(f"Early stopping at step {i} with loss {curr_loss}")
                    break
            prev_time = time.perf_counter()
            if i > 0 and log_interval < 1000:
                log_interval *= 10
    # NOTE(review): `curr_loss` is unbound if max_n_steps == 0; callers
    # always pass > 0 steps.
    if save_dir:
        data_sampler.save_model_output(best_model, save_path=f"{save_dir}/base_nf_best")
        torch.save({
            'step': i,
            'model_state_dict': best_model.state_dict(),
            'loss': curr_loss,
        }, f"{save_dir}/base_nf_best.pt")

    return best_model
198
+
@@ -0,0 +1,269 @@
1
+ import os
2
+ import mcubes
3
+ import numpy as np
4
+ import packaging
5
+ import torch
6
+ import trimesh
7
+
8
def get_device(module):
    """Device of `module`'s first parameter, falling back to its first buffer."""
    for param in module.parameters():
        return param.device
    # parameterless module: fall back to buffers (raises StopIteration if none)
    return next(module.buffers()).device
13
+
14
def check_shape_equality(*images): # borrowed from scikit-image
    """Raise ValueError unless every image shares the first image's shape."""
    reference_shape = images[0].shape
    for image in images[1:]:
        if image.shape != reference_shape:
            raise ValueError('Input images must have the same dimensions.')
    return
20
+
21
+
22
+ #################
23
+ ## borrowed from scikit-image for psnr metric
24
+ dtype_range = {
25
+ bool: (False, True),
26
+ np.bool_: (False, True),
27
+ float: (-1, 1),
28
+ np.float16: (-1, 1),
29
+ np.float32: (-1, 1),
30
+ np.float64: (-1, 1),
31
+ }
32
+ new_float_type = {
33
+ # preserved types
34
+ np.float32().dtype.char: np.float32,
35
+ np.float64().dtype.char: np.float64,
36
+ np.complex64().dtype.char: np.complex64,
37
+ np.complex128().dtype.char: np.complex128,
38
+ # altered types
39
+ np.float16().dtype.char: np.float32,
40
+ 'g': np.float64, # np.float128 ; doesn't exist on windows
41
+ 'G': np.complex128, # np.complex256 ; doesn't exist on windows
42
+ }
43
def _supported_float_type(input_dtype, allow_complex=False):
    """Return an appropriate floating-point dtype for a given dtype.

    float32, float64, complex64, complex128 are preserved; float16 is
    promoted to float32; extended precision is demoted; anything else is
    cast to float64. A tuple of dtypes is first converted element-wise and
    then reduced with np.result_type.

    Parameters
    ----------
    input_dtype : np.dtype or tuple of np.dtype
        The input dtype(s).
    allow_complex : bool, optional
        If False, raise ValueError on complex-valued inputs.

    Returns
    -------
    float_type : dtype
        Floating-point dtype for the image.
    """
    if isinstance(input_dtype, tuple):
        converted = (_supported_float_type(d) for d in input_dtype)
        return np.result_type(*converted)
    dtype = np.dtype(input_dtype)
    if not allow_complex and dtype.kind == 'c':
        raise ValueError("complex valued input is not supported")
    return new_float_type.get(dtype.char, np.float64)
72
def _as_floats(image0, image1):
    """Cast both images to their shared supported floating-point dtype."""
    target_dtype = _supported_float_type((image0.dtype, image1.dtype))
    return (np.asarray(image0, dtype=target_dtype),
            np.asarray(image1, dtype=target_dtype))
80
def mean_squared_error(image0, image1):
    """
    Compute the mean-squared error between two images.

    Parameters
    ----------
    image0, image1 : ndarray
        Images of any (but matching) dimensionality.

    Returns
    -------
    mse : float
        The mean-squared error, accumulated in float64.

    Raises
    ------
    ValueError
        If the shapes differ.
    """
    check_shape_equality(image0, image1)
    image0, image1 = _as_floats(image0, image1)
    diff = image0 - image1
    return np.mean(diff * diff, dtype=np.float64)
104
+
105
def peak_signal_noise_ratio(image_true, image_test, *, data_range=None):
    """
    Compute the peak signal to noise ratio (PSNR) for an image.

    Borrowed from scikit-image (skimage.metrics.peak_signal_noise_ratio).

    Parameters
    ----------
    image_true : ndarray
        Ground-truth image, same shape as image_test.
    image_test : ndarray
        Test image.
    data_range : int, optional
        Distance between minimum and maximum possible values. By default
        estimated from image_true's dtype.

    Returns
    -------
    psnr : float
        The PSNR metric in dB.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    """
    check_shape_equality(image_true, image_test)

    if data_range is None:
        if image_true.dtype != image_test.dtype:
            # data_range inference only looks at image_true's dtype
            print(
                "Inputs have mismatched dtype. Setting data_range based on "
                "image_true."
            )
        dmin, dmax = dtype_range[image_true.dtype.type]
        true_min, true_max = np.min(image_true), np.max(image_true)
        if true_max > dmax or true_min < dmin:
            raise ValueError(
                "image_true has intensity values outside the range expected "
                "for its data type. Please manually specify the data_range."
            )
        if true_min >= 0:
            # most common case (255 for uint8, 1 for float)
            data_range = dmax
        else:
            data_range = dmax - dmin

    image_true, image_test = _as_floats(image_true, image_test)

    err = mean_squared_error(image_true, image_test)
    data_range = float(data_range) # prevent overflow for small integer types
    return 10 * np.log10((data_range**2) / err)
162
+ #################
163
+ #################
164
+
165
+
166
def measure_gpu_memory(device):
    """Print the CUDA memory currently allocated on `device` (no-op without CUDA)."""
    if torch.cuda.is_available():
        allocated_memory = torch.cuda.memory_allocated(device)
        # (removed an unused torch.cuda.max_memory_allocated fetch whose
        # print statement had been commented out)
        print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
172
+
173
+
174
def get_model_checkpoint_size_mb(checkpoint_path, return_param_count=False):
    '''
    Size (in MB) of the tensors stored in a model checkpoint.

    checkpoint_path: full path containing the file extension; the file must
    hold a dict with a "model_state_dict" entry.
    return_param_count: if True, also return the total element count.
    '''
    state_dict = torch.load(checkpoint_path, weights_only=True)["model_state_dict"]
    size_model = 0
    param_count = 0
    for tensor in state_dict.values():
        size_model += tensor.nelement() * tensor.element_size()
        param_count += tensor.nelement()
    print(f"\tmodel size: {(size_model / 1e6):.3f} MB")
    print(f"\tnum params: {param_count}")
    if return_param_count:
        return size_model / 1e6, param_count
    return size_model / 1e6
194
+
195
+
196
def mean_relative_l2(pred, target, eps=0.01):
    """Relative L2 loss: squared error normalized by detached pred**2 + eps."""
    residual = pred - target.to(pred.dtype)
    denom = pred.detach()**2 + eps
    return (residual**2 / denom).mean()
199
+
200
def mape_loss(pred, target, reduction='mean'):
    """Mean absolute percentage error with a 1e-2 stabilizer.

    pred, target: [B, 1] torch tensors. reduction='mean' returns a scalar;
    any other value returns the element-wise loss map unchanged.
    """
    scale = 1 / (target.abs() + 1e-2)
    loss = (pred - target).abs() * scale
    return loss.mean() if reduction == 'mean' else loss
210
+
211
+
212
+ ### marching cubes helpers from torch-ngp ###
213
def custom_meshgrid(*args):
    """torch.meshgrid wrapper pinning 'ij' indexing across torch versions.

    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html

    torch >= 1.10 takes (and warns without) an `indexing` kwarg; older
    versions only support the implicit 'ij' behavior. The original used
    `packaging.version.parse`, but `import packaging` alone does not expose
    the `version` submodule on modern packaging releases (AttributeError);
    EAFP avoids the dependency entirely.
    """
    try:
        return torch.meshgrid(*args, indexing='ij')
    except TypeError:
        # torch < 1.10: no `indexing` kwarg; default is already 'ij'
        return torch.meshgrid(*args)
219
+
220
def extract_fields(bound_min, bound_max, resolution, query_func, device=torch.device("cpu")):
    """Evaluate query_func on a resolution^3 grid, returning a numpy volume.

    Args:
        bound_min, bound_max: per-axis bounds (indexable, length 3).
        resolution: number of grid points per axis.
        query_func: maps an [N, 3] point tensor to field values; it is
            responsible for moving points to the model's device if needed.
        device: device on which the coordinate grid is built.

    Returns:
        float32 numpy array of shape [resolution]*3.
    """
    N = 64  # chunk edge length: evaluate at most 64^3 points per call
    X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N)

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3]
                    val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [N, 1] --> [x, y, z]
                    # write the chunk back into its slot of the full volume
                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
    return u
236
+
237
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    """Run marching cubes on the sampled field and map vertices to world space.

    Args:
        bound_min, bound_max: tensors with the axis-aligned bounds.
        resolution: grid resolution per axis.
        threshold: iso-level passed to mcubes.
        query_func: maps [N, 3] points to field values.

    Returns:
        (vertices, triangles) numpy arrays; vertices are in world coordinates.
    """
    # NOTE(review): extract_fields is called without a device argument, so
    # the coordinate grid is built on CPU regardless of the model's device.
    u = extract_fields(bound_min, bound_max, resolution, query_func)
    vertices, triangles = mcubes.marching_cubes(u, threshold)

    b_max_np = bound_max.detach().cpu().numpy()
    b_min_np = bound_min.detach().cpu().numpy()

    # rescale grid-index vertex coordinates into [bound_min, bound_max]
    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices, triangles
246
+
247
+
248
+
249
def save_mesh(save_path, sdf_model, device, resolution=256):
    """Extract a mesh from `sdf_model` via marching cubes and export it.

    Args:
        save_path: destination file; the extension selects trimesh's export format.
        sdf_model: callable mapping [N, 3] points to SDF values.
        device: device on which to evaluate the model.
        resolution: marching-cubes grid resolution per axis.
    """
    print(f"==> Saving mesh to {save_path}")

    # guard: os.makedirs("") raises FileNotFoundError when save_path has no
    # directory component
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    def query_func(pts):
        # move query points to the model's device before evaluation
        pts = pts.to(device)
        with torch.no_grad():
            # with torch.cuda.amp.autocast(enabled=False):
            sdfs = sdf_model(pts)
        return sdfs

    bounds_min = torch.FloatTensor([-1, -1, -1])
    bounds_max = torch.FloatTensor([1, 1, 1])

    vertices, triangles = extract_geometry(bounds_min, bounds_max, resolution=resolution, threshold=0, query_func=query_func)

    mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
    mesh.export(save_path)

    print(f"==> Finished saving mesh.")
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: lora-nf
3
+ Version: 0.1.1
4
+ Summary: Low-Rank Adaptation of Neural Fields
5
+ Home-page: https://github.com/dinhanhtruong/LoRA-NF
6
+ Author: Anh Truong
7
+ License: MIT License
8
+ Classifier: Programming Language :: Python :: 3.12
9
+ Requires-Dist: torch
10
+ Requires-Dist: numpy
11
+ Requires-Dist: imageio
12
+ Requires-Dist: trimesh
13
+ Requires-Dist: pysdf
14
+ Requires-Dist: PyMCubes
15
+ Requires-Dist: packaging
16
+ Dynamic: author
17
+ Dynamic: classifier
18
+ Dynamic: home-page
19
+ Dynamic: license
20
+ Dynamic: requires-dist
21
+ Dynamic: summary
@@ -0,0 +1,13 @@
1
+ README.md
2
+ setup.py
3
+ lora_nf/__init__.py
4
+ lora_nf/common.py
5
+ lora_nf/data_samplers.py
6
+ lora_nf/torch_modules.py
7
+ lora_nf/train_lora.py
8
+ lora_nf/util.py
9
+ lora_nf.egg-info/PKG-INFO
10
+ lora_nf.egg-info/SOURCES.txt
11
+ lora_nf.egg-info/dependency_links.txt
12
+ lora_nf.egg-info/requires.txt
13
+ lora_nf.egg-info/top_level.txt
@@ -0,0 +1,7 @@
1
+ torch
2
+ numpy
3
+ imageio
4
+ trimesh
5
+ pysdf
6
+ PyMCubes
7
+ packaging
@@ -0,0 +1 @@
1
+ lora_nf
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
lora_nf-0.1.1/setup.py ADDED
@@ -0,0 +1,21 @@
1
# Packaging script for the lora-nf distribution
# (Low-Rank Adaptation of Neural Fields).
from setuptools import setup

setup(
    name='lora-nf',
    version='0.1.1',
    description='Low-Rank Adaptation of Neural Fields',
    url='https://github.com/dinhanhtruong/LoRA-NF',
    author='Anh Truong',
    license='MIT License',
    packages=['lora_nf'],
    install_requires=['torch',
                      'numpy',
                      'imageio',
                      'trimesh',
                      'pysdf',
                      'PyMCubes',
                      'packaging'],
    classifiers=[
        'Programming Language :: Python :: 3.12',
    ],
)