lora-nf 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lora_nf-0.1.1/PKG-INFO +21 -0
- lora_nf-0.1.1/README.md +2 -0
- lora_nf-0.1.1/lora_nf/__init__.py +1 -0
- lora_nf-0.1.1/lora_nf/common.py +151 -0
- lora_nf-0.1.1/lora_nf/data_samplers.py +169 -0
- lora_nf-0.1.1/lora_nf/torch_modules.py +122 -0
- lora_nf-0.1.1/lora_nf/train_lora.py +198 -0
- lora_nf-0.1.1/lora_nf/util.py +269 -0
- lora_nf-0.1.1/lora_nf.egg-info/PKG-INFO +21 -0
- lora_nf-0.1.1/lora_nf.egg-info/SOURCES.txt +13 -0
- lora_nf-0.1.1/lora_nf.egg-info/dependency_links.txt +1 -0
- lora_nf-0.1.1/lora_nf.egg-info/requires.txt +7 -0
- lora_nf-0.1.1/lora_nf.egg-info/top_level.txt +1 -0
- lora_nf-0.1.1/setup.cfg +4 -0
- lora_nf-0.1.1/setup.py +21 -0
lora_nf-0.1.1/PKG-INFO
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: lora-nf
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Low-Rank Adaptation of Neural Fields
|
|
5
|
+
Home-page: https://github.com/dinhanhtruong/LoRA-NF
|
|
6
|
+
Author: Anh Truong
|
|
7
|
+
License: MIT License
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
9
|
+
Requires-Dist: torch
|
|
10
|
+
Requires-Dist: numpy
|
|
11
|
+
Requires-Dist: imageio
|
|
12
|
+
Requires-Dist: trimesh
|
|
13
|
+
Requires-Dist: pysdf
|
|
14
|
+
Requires-Dist: PyMCubes
|
|
15
|
+
Requires-Dist: packaging
|
|
16
|
+
Dynamic: author
|
|
17
|
+
Dynamic: classifier
|
|
18
|
+
Dynamic: home-page
|
|
19
|
+
Dynamic: license
|
|
20
|
+
Dynamic: requires-dist
|
|
21
|
+
Dynamic: summary
|
lora_nf-0.1.1/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string; matches the "Version: 0.1.1" field in PKG-INFO.
__version__ = "0.1.1"
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
|
6
|
+
# provided that the following conditions are met:
|
|
7
|
+
# * Redistributions of source code must retain the above copyright notice, this list of
|
|
8
|
+
# conditions and the following disclaimer.
|
|
9
|
+
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
|
10
|
+
# conditions and the following disclaimer in the documentation and/or other materials
|
|
11
|
+
# provided with the distribution.
|
|
12
|
+
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
|
13
|
+
# to endorse or promote products derived from this software without specific prior written
|
|
14
|
+
# permission.
|
|
15
|
+
#
|
|
16
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
|
17
|
+
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
|
19
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
20
|
+
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
21
|
+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
|
22
|
+
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
23
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
24
|
+
|
|
25
|
+
import imageio
|
|
26
|
+
import numpy as np
|
|
27
|
+
import os
|
|
28
|
+
import struct
|
|
29
|
+
|
|
30
|
+
# Directory two levels above this file, i.e. the directory that contains the lora_nf package.
ROOT_DIR = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)


def mse2psnr(x):
    """Convert a mean-squared error into PSNR (in dB), assuming a peak signal value of 1."""
    scaled_log = -10. * np.log(x)
    return scaled_log / np.log(10.)
|
|
34
|
+
|
|
35
|
+
def write_image_imageio(img_file, img, quality):
    """
    Quantize a float image (clipped to [0, 1]) to uint8 and write it with imageio.

    For .jpg/.jpeg targets, channels beyond the first three are dropped, single-channel
    images are squeezed with np.squeeze, and `quality` plus `subsampling=0` are passed
    through to imageio.
    """
    quantized = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    write_kwargs = {}
    extension = os.path.splitext(img_file)[1].lower()
    if extension in (".jpg", ".jpeg"):
        if quantized.ndim >= 3 and quantized.shape[2] > 3:
            quantized = quantized[:, :, :3]
        if quantized.ndim >= 3 and quantized.shape[2] == 1:  # grayscale
            quantized = np.squeeze(quantized)
        write_kwargs["quality"] = quality
        write_kwargs["subsampling"] = 0
    imageio.imwrite(img_file, quantized, **write_kwargs)
|
|
46
|
+
|
|
47
|
+
def read_image_imageio(img_file):
    """
    Read an image with imageio and return it as float32 divided by 255.

    The result always has a channel axis: 2D (grayscale) images become [H, W, 1].
    NOTE(review): the fixed 1/255 scale assumes 8-bit input data.
    """
    pixels = np.asarray(imageio.imread(img_file)).astype(np.float32)
    if pixels.ndim == 2:
        pixels = pixels[..., np.newaxis]
    return pixels / 255.0
|
|
53
|
+
|
|
54
|
+
def srgb_to_linear(img):
    """Apply the inverse sRGB transfer curve, mapping sRGB-encoded values to linear values."""
    threshold = 0.04045
    curved = np.power((img + 0.055) / 1.055, 2.4)
    straight = img / 12.92
    return np.where(img > threshold, curved, straight)
|
|
57
|
+
|
|
58
|
+
def linear_to_srgb(img):
    """Apply the forward sRGB transfer curve, mapping linear values to sRGB-encoded values."""
    threshold = 0.0031308
    curved = 1.055 * (img ** (1.0 / 2.4)) - 0.055
    straight = 12.92 * img
    return np.where(img > threshold, curved, straight)
|
|
61
|
+
|
|
62
|
+
def read_image(file):
    """
    Load an image as a float32 array of shape [H, W, C] in linear color space.

    ``.bin`` files use a raw layout: two native C ints (height, width) followed by
    height*width*4 float16 values. They are returned as [h, w, 4] without any color conversion.

    Any other format is read through read_image_imageio and converted from sRGB to linear.
    4-channel images keep their alpha channel, and the RGB channels are premultiplied by it.
    """
    if os.path.splitext(file)[1] == ".bin":
        with open(file, "rb") as f:
            raw = f.read()  # named `raw` rather than `bytes` to avoid shadowing the builtin
        h, w = struct.unpack("ii", raw[:8])
        return np.frombuffer(raw, dtype=np.float16, count=h * w * 4, offset=8).astype(np.float32).reshape([h, w, 4])

    img = read_image_imageio(file)
    if img.shape[2] == 4:
        img[..., 0:3] = srgb_to_linear(img[..., 0:3])
        # Premultiply alpha
        img[..., 0:3] *= img[..., 3:4]
    else:
        img = srgb_to_linear(img)
    return img
|
|
77
|
+
|
|
78
|
+
def write_image(file, img, quality=100):
    """
    Save a linear-color float image of shape [H, W, C].

    ``.bin`` output: channels are padded up to 4 with ones. The file then contains a header of
    two native C ints (height, width) followed by the pixels as float16.

    Other formats: 4-channel images are un-premultiplied by alpha (pixels with alpha == 0 get
    RGB 0) and their RGB is converted to sRGB. Other images are converted to sRGB as a whole.
    The result is then written with write_image_imageio using `quality`.
    """
    if os.path.splitext(file)[1] != ".bin":
        if img.shape[2] == 4:
            img = np.copy(img)
            alpha = img[..., 3:4]
            # Unmultiply alpha
            img[..., 0:3] = np.divide(img[..., 0:3], alpha, out=np.zeros_like(img[..., 0:3]), where=alpha != 0)
            img[..., 0:3] = linear_to_srgb(img[..., 0:3])
        else:
            img = linear_to_srgb(img)
        write_image_imageio(file, img, quality)
        return

    missing_channels = 4 - img.shape[2]
    if missing_channels > 0:
        padding = np.ones([img.shape[0], img.shape[1], missing_channels])
        img = np.dstack((img, padding))
    with open(file, "wb") as f:
        f.write(struct.pack("ii", img.shape[0], img.shape[1]))
        f.write(img.astype(np.float16).tobytes())
|
|
94
|
+
|
|
95
|
+
def trim(error, skip=0.000001):
    """
    Return a trimmed mean of all entries of `error`.

    The values are flattened and sorted. Then int(skip * N) entries are dropped from each end
    before averaging.
    """
    ordered = np.sort(error.flatten())
    count = ordered.size
    cut = int(skip * count)
    kept = ordered[cut:count - cut]
    return kept.mean()
|
|
100
|
+
|
|
101
|
+
def luminance(a):
    """
    Weighted luma of the first three channels of an [H, W, >=3] array.

    Negative values are clamped to 0 and a 0.4545... (about 1/2.2) power is applied before
    weighting with the Rec. 709 coefficients 0.2126 / 0.7152 / 0.0722.
    """
    encoded = np.maximum(0, a) ** 0.4545454545
    red, green, blue = encoded[:, :, 0], encoded[:, :, 1], encoded[:, :, 2]
    return 0.2126 * red + 0.7152 * green + 0.0722 * blue
|
|
104
|
+
|
|
105
|
+
def L1(img, ref):
    """Per-element absolute error |img - ref|."""
    difference = img - ref
    return np.abs(difference)


def APE(img, ref):
    """Per-element absolute percentage error, with 1e-2 added to the denominator for stability."""
    denominator = 1e-2 + ref
    return L1(img, ref) / denominator


def SAPE(img, ref):
    """Per-element symmetric absolute percentage error (denominator: 1e-2 + mean of img and ref)."""
    denominator = 1e-2 + (ref + img) / 2.
    return L1(img, ref) / denominator


def L2(img, ref):
    """Per-element squared error (img - ref)**2."""
    difference = img - ref
    return difference ** 2


def RSE(img, ref):
    """Per-element relative squared error, normalized by 1e-2 + ref**2."""
    denominator = 1e-2 + ref ** 2
    return L2(img, ref) / denominator


def rgb_mean(img):
    """Average over axis 2 (the channel axis of an [H, W, C] image)."""
    return np.mean(img, axis=2)
|
|
122
|
+
|
|
123
|
+
def compute_error_img(metric, img, ref):
    """
    Compute the error between `img` and `ref` for the named metric.

    Before comparison, non-finite values of `img` are replaced with 0 and negatives are clamped
    to 0. This is done on a copy, so the caller's array is left untouched. `ref` is used as given.

    Supported metrics:
    - "MAE": L1. "MAPE": APE. "SMAPE": SAPE. "MSE": L2. "MRSE": RSE.
    - "MScE": L2 after clipping both inputs to [0, 1].
    - "MRScE": RSE after clipping both inputs to [0, 100].
    - "MtRSE": trimmed mean of RSE. This returns a scalar rather than a per-pixel map.

    Raises:
        ValueError: for an unknown metric name.
    """
    # np.where builds a new array; the previous in-place masked assignment overwrote the caller's img.
    img = np.where(np.isfinite(img), img, 0)
    img = np.maximum(img, 0.)
    if metric == "MAE":
        return L1(img, ref)
    elif metric == "MAPE":
        return APE(img, ref)
    elif metric == "SMAPE":
        return SAPE(img, ref)
    elif metric == "MSE":
        return L2(img, ref)
    elif metric == "MScE":
        return L2(np.clip(img, 0.0, 1.0), np.clip(ref, 0.0, 1.0))
    elif metric == "MRSE":
        return RSE(img, ref)
    elif metric == "MtRSE":
        return trim(RSE(img, ref))
    elif metric == "MRScE":
        return RSE(np.clip(img, 0, 100), np.clip(ref, 0, 100))

    raise ValueError(f"Unknown metric: {metric}.")
|
|
144
|
+
|
|
145
|
+
def compute_error(metric, img, ref):
    """
    Return the mean of compute_error_img(metric, img, ref).

    Non-finite entries of the error map count as 0. A 3D map is first averaged over its
    channel axis (axis 2), and then over everything else.
    """
    metric_map = compute_error_img(metric, img, ref)
    # np.where instead of masked item assignment: for "MtRSE" compute_error_img returns a
    # NumPy scalar (from trim), and NumPy scalars do not support item assignment.
    metric_map = np.where(np.isfinite(metric_map), metric_map, 0)
    if metric_map.ndim == 3:
        metric_map = np.mean(metric_map, axis=2)
    return np.mean(metric_map)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
from .common import read_image, write_image
|
|
5
|
+
import trimesh
|
|
6
|
+
from .util import save_mesh
|
|
7
|
+
import pysdf
|
|
8
|
+
|
|
9
|
+
class DataSampler:
    """
    Interface for sources of neural-field training data.

    A subclass draws (position, value) pairs from a target field and can render the
    reconstruction produced by a model. Both methods here only raise NotImplementedError.
    """

    def sample_batch(self, n_samples: int, device: torch.device):
        """
        Draw input positions together with the field values at those positions.

        Args
        - n_samples: how many samples to draw
        - device: torch device (e.g. "cpu", "cuda:0") on which the returned tensors should live

        Returns:
        - inputs: torch tensor [n_samples, data_input_dim] of positions, e.g. [n_samples, 2] for image xy positions
        - targets: torch tensor [n_samples, data_output_dim] of field values, e.g. [n_samples, 3] for image RGB values
        """
        raise NotImplementedError

    def save_model_output(self, model: torch.nn.Module, save_path: str):
        """
        Write a reconstruction of `model` to `save_path`, given without extension.

        The directory containing `save_path` must already exist.

        Args:
        - model: neural field nn.Module to evaluate

        Returns nothing.
        """
        raise NotImplementedError
|
|
33
|
+
|
|
34
|
+
class Image(DataSampler, torch.nn.Module):
    """
    DataSampler over a 2D image.

    Inputs are normalized positions in [0, 1]^2, with the column (width) coordinate first.
    Targets are the bilinearly interpolated pixel values at those positions.
    """

    def __init__(self, filename: str, device: torch.device):
        """
        Args
        - filename: path to image (read with read_image; an alpha channel is dropped)
        - device: torch.device holding the image tensor and the reconstruction grid
        """
        super().__init__()
        self.data = read_image(filename)
        # remove alpha channel
        if self.data.ndim > 2 and self.data.shape[2] == 4:
            self.data = self.data[:, :, :3]  # keep RGB
        self.orig_data_npy = self.data
        self.shape = self.data.shape
        self.data = torch.from_numpy(self.data).float().to(device)
        self.device = device

        # pixel-center grid used for model output reconstruction
        resolution = self.data.shape[0:2]  # (height, width)
        half_dx = 0.5 / resolution[0]  # half pixel size along the height axis
        half_dy = 0.5 / resolution[1]  # half pixel size along the width axis
        xs = torch.linspace(half_dx, 1 - half_dx, resolution[0])
        ys = torch.linspace(half_dy, 1 - half_dy, resolution[1])
        xv, yv = torch.meshgrid([xs, ys], indexing="ij")
        self.img_shape = resolution + torch.Size([self.data.shape[2]])
        # width coordinate first, matching the [shape[1], shape[0]] scaling in forward();
        # row-major order so the outputs reshape directly to img_shape
        self.xy = torch.stack((yv.flatten(), xv.flatten())).t().to(device)

    def forward(self, xs, interpolate=True):
        """
        Look up image values at normalized positions.

        Args
        - xs: [N, 2] tensor of positions in [0, 1]^2, width (column) coordinate first
        - interpolate: bilinear interpolation if True; otherwise the pixel at the truncated index

        Returns an [N, C] tensor of pixel values. It is computed under torch.no_grad();
        indices are clamped to the image borders.
        """
        with torch.no_grad():
            # Bilinearly filtered lookup from the image. Not super fast,
            # but less than ~20% of the overall runtime of this example.
            shape = self.shape

            xs = xs * torch.tensor([shape[1], shape[0]], device=xs.device).float()
            indices = xs.long()

            x0 = indices[:, 0].clamp(min=0, max=shape[1] - 1)
            y0 = indices[:, 1].clamp(min=0, max=shape[0] - 1)
            if interpolate:
                lerp_weights = xs - indices.float()
                x1 = (x0 + 1).clamp(max=shape[1] - 1)
                y1 = (y0 + 1).clamp(max=shape[0] - 1)

                return (
                    self.data[y0, x0] * (1.0 - lerp_weights[:, 0:1]) * (1.0 - lerp_weights[:, 1:2]) +
                    self.data[y0, x1] * lerp_weights[:, 0:1] * (1.0 - lerp_weights[:, 1:2]) +
                    self.data[y1, x0] * (1.0 - lerp_weights[:, 0:1]) * lerp_weights[:, 1:2] +
                    self.data[y1, x1] * lerp_weights[:, 0:1] * lerp_weights[:, 1:2]
                )

            return self.data[y0, x0]  # no interpolation

    def sample_batch(self, n_samples, device):
        """Sample uniform random positions in [0, 1]^2 and the interpolated pixel values there."""
        assert device == self.device
        input_xy = torch.rand([n_samples, 2], dtype=torch.float, device=device)
        image_rgb = self.forward(input_xy)
        return input_xy, image_rgb

    def save_model_output(self, model, save_path):
        """Evaluate `model` on every pixel center and write the result, clamped to [0, 1], to `{save_path}.png`."""
        # Callers may invoke this outside torch.no_grad(); without it, the full-resolution
        # forward pass would build (and keep alive) an autograd graph over every pixel.
        with torch.no_grad():
            prediction = model(self.xy).reshape(self.img_shape).clamp(0.0, 1.0)
        write_image(f"{save_path}.png", prediction.cpu().numpy())
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class SDF(DataSampler):
    """
    DataSampler over the signed distance field of a triangle mesh.

    The mesh is loaded with trimesh, centered on its bounding-box center and uniformly scaled
    so that the bounding-box diagonal has length 1.9, which fits it in [-1, 1]^3. Distances
    are queried with pysdf.
    """

    def __init__(self, path: str, device: torch.device, num_samples=2**18, clip_sdf=None, transformation_save_dir=""):
        """
        Args
        - path: mesh file, loaded with trimesh.load(path, force='mesh')
        - device: torch device on which sampled batches are returned
        - num_samples: stored as self.num_samples; must be divisible by 8
        - clip_sdf: if not None, sampled SDF values are clipped to [-clip_sdf, clip_sdf]
        - transformation_save_dir: directory for data_normalization_transformation.npz
          ("" means the current working directory). It is created if missing.
        """
        super().__init__()
        self.path = path
        # previously `self.device = self.device`, which raised AttributeError on construction
        self.device = device

        # load obj
        self.mesh = trimesh.load(path, force='mesh')

        # normalize to [-1, 1] (different from instant-sdf where is [0, 1]) via scaling and translation
        vs = self.mesh.vertices
        vmin = vs.min(0)
        vmax = vs.max(0)
        v_center = (vmin + vmax) / 2
        v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95

        # TODO: save normalizing transformation of base model and apply same transformation to deformed model
        # os.path.join instead of f"{dir}/...": with the default "" the old path pointed at the filesystem root.
        if transformation_save_dir:
            os.makedirs(transformation_save_dir, exist_ok=True)
        transformation_path = os.path.join(transformation_save_dir, "data_normalization_transformation.npz")
        if not os.path.exists(transformation_path):
            print("saving base mesh normalization transformation")
            np.savez(transformation_path, v_scale=v_scale, v_center=v_center)
        else:
            # Reuse of the saved transformation is currently disabled: every mesh is normalized
            # with its own bounding box (see the TODO above).
            print("######################")
            print("####### TEMP DISABLED CONSISTENT SDF NORMALIZATION")
            print("######################")
            print("######################")
            print("######################")
            print("######################")
            # print("using existing normalization transformation")
            # loaded_transformation = np.load(transformation_path)
            # v_scale = loaded_transformation["v_scale"]
            # v_center = loaded_transformation["v_center"]

        print("scale: ", v_scale)
        print("center: ", v_center)
        # apply transformation to verts
        vs = (vs - v_center[None, :]) * v_scale
        self.mesh.vertices = vs

        print(f"[INFO] mesh verts & faces: {self.mesh.vertices.shape} & {self.mesh.faces.shape}")

        if not self.mesh.is_watertight:
            print(f"[WARN] mesh is not watertight! SDF maybe incorrect.")
            #trimesh.Scene([self.mesh]).show()

        self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces)

        self.num_samples = num_samples
        assert self.num_samples % 8 == 0, "num_samples must be divisible by 8."
        self.clip_sdf = clip_sdf

    def sample_batch(self, n_samples, device):
        """
        Draw n_samples (point, sdf) pairs as float32 tensors of shape [n_samples, 3] and [n_samples, 1].

        The points are split as follows:
        - the first n_samples // 2 lie exactly on the surface, with target 0;
        - the remaining surface samples are jittered with N(0, 0.01^2) noise (near-surface);
        - the last n_samples // 8 are uniform in [-1, 1]^3.
        For n_samples divisible by 8 this is the original 1/2 : 3/8 : 1/8 split. Other sizes are
        handled by giving the remainder to the surface samples.
        """
        assert device == self.device
        n_uniform = n_samples // 8  # random uniform points
        n_surface = n_samples - n_uniform  # surface and near-surface points
        n_exact = n_samples // 2  # points kept exactly on the surface

        # surface query points
        points_surface = self.mesh.sample(n_surface)
        # near-surface points
        points_surface[n_exact:] += 0.01 * np.random.randn(n_surface - n_exact, 3)
        # random uniform points
        points_uniform = np.random.rand(n_uniform, 3) * 2 - 1
        points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32)

        # float32 targets to match the float32 points / model outputs (was float64)
        sdfs = np.zeros((n_samples, 1), dtype=np.float32)
        # pysdf's sign is negated here; NOTE(review): assuming pysdf is positive-inside,
        # targets are negative inside the mesh.
        sdfs[n_exact:] = -self.sdf_fn(points[n_exact:])[:, None].astype(np.float32)

        # clip sdf
        if self.clip_sdf is not None:
            sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf)

        return torch.tensor(points, device=device), torch.tensor(sdfs, device=device)

    def save_model_output(self, model: torch.nn.Module, save_path: str):
        """Extract and save a mesh of the model's SDF via save_mesh at resolution 256."""
        save_mesh(save_path, model, self.device, resolution=256)
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
class CustomFrequencyEncoding(nn.Module):
    """
    Parameter-free NeRF-style positional encoding.

    It concatenates (optionally) the input with sin/cos of the input scaled by a set of
    frequencies, along the last axis.
    """

    def __init__(self):
        super(CustomFrequencyEncoding, self).__init__()

    def get_encoding_output_dim(self, input_dim):
        """Return the last-axis size of forward()'s output (default arguments) for inputs of size input_dim."""
        return self.forward(torch.zeros((1, input_dim))).shape[-1]

    def forward(  # from https://github.com/krrish94/nerf-pytorch/blob/master/nerf/nerf_helpers.py
        self, tensor, num_encoding_functions=10, include_input=True, log_sampling=True
    ) -> torch.Tensor:
        r"""Apply positional encoding to the input.

        Args:
            tensor (torch.Tensor): Input tensor to be positionally encoded (along its last axis).
            num_encoding_functions (optional, int): Number of frequency bands; each band
                contributes one sin and one cos term (default: 10).
            include_input (optional, bool): Whether or not to include the input in the
                positional encoding (default: True).
            log_sampling (optional, bool): If True, the frequencies are 2**0, 2**1, ...,
                2**(num_encoding_functions - 1). Otherwise they are spaced linearly between
                1 and 2**(num_encoding_functions - 1) (default: True).

        Returns:
            (torch.Tensor): Positional encoding of the input tensor. For an input with last-axis
            size D, the last axis has size D * (1 + 2 * num_encoding_functions) when include_input
            is True, and D * 2 * num_encoding_functions otherwise.
        """
        encoding = [tensor] if include_input else []
        if log_sampling:
            frequency_bands = 2.0 ** torch.linspace(
                0.0,
                num_encoding_functions - 1,
                num_encoding_functions,
                dtype=tensor.dtype,
                device=tensor.device,
            )
        else:
            frequency_bands = torch.linspace(
                2.0 ** 0.0,
                2.0 ** (num_encoding_functions - 1),
                num_encoding_functions,
                dtype=tensor.dtype,
                device=tensor.device,
            )

        for freq in frequency_bands:
            for func in [torch.sin, torch.cos]:
                encoding.append(func(tensor * freq))

        if not encoding:
            # No input and no frequency bands: torch.cat([]) would raise, so return an empty last axis.
            return tensor[..., :0]
        # Special case, for no positional encoding
        if len(encoding) == 1:
            return encoding[0]
        return torch.cat(encoding, dim=-1)
|
|
54
|
+
|
|
55
|
+
class LoRA_MLP(nn.Module):
    def __init__(self, base_mlp, rank):
        '''
        LoRA-augmented MLP.

        Mirrors the direct children of base_mlp. Each nn.Linear is wrapped in a LoRALinear,
        which freezes that linear layer's weight and bias in place. Every other layer
        (activation, positional encoding, ...) is reused as-is. Layer objects are shared with
        base_mlp, not copied.

        Args:
            base_mlp: iterable of modules (e.g. nn.Sequential). Only its direct children are
                inspected, so nn.Linear layers nested inside sub-containers are not adapted.
            rank: requested LoRA rank for every adapted linear layer.
        '''
        super(LoRA_MLP, self).__init__()

        # for each parent MLP linear layer, store a low-rank adaptor (LoRALinear layer).
        # mimic parent MLP but replace linear with LoRALinear
        self.sequential = nn.Sequential()  # LoRALinears plus the reused non-linear layers

        for base_layer in base_mlp:
            if isinstance(base_layer, nn.Linear):
                # add lora layer
                self.sequential.append(LoRALinear(base_layer, r=rank))
            else:
                # keep parent layer (activation or positional encoding)
                self.sequential.append(base_layer)

    def as_sequential(self):
        """Return the underlying nn.Sequential (LoRA-wrapped linears plus reused layers)."""
        return self.sequential

    def get_lora_weights(self):
        """
        Return one tensor per LoRALinear layer, in layer order: (A @ B).T.

        Each tensor has shape [out_features, in_features], the same layout as nn.Linear.weight.

        NOTE: the LoRALinear `scaling` factor is not applied here. The weight update actually
        used in forward() is `scaling * (A @ B).T`.
        """
        lora_weights = []
        for lora_mlp_layer in self.sequential:
            if isinstance(lora_mlp_layer, LoRALinear):
                lora_weights.append((lora_mlp_layer.A @ lora_mlp_layer.B).T)  # [out_features, in_features]
        return lora_weights

    def forward(self, x):
        return self.sequential(x)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class LoRALinear(nn.Module):
    def __init__(self, base_linear, r=16, alpha=1, A=None):
        '''
        Low-rank adapter around a frozen nn.Linear.

        Computes y = base_linear(x) + (x @ A @ B) * alpha / r.

        Args:
            base_linear: nn.Linear to adapt. Its weight (and bias, if any) are set to
                requires_grad=False in place.
            r: requested rank (>= 1). The actual rank is min(r, in_features, out_features),
                but the scaling alpha / r uses the requested r.
            alpha: numerator of the LoRA scaling factor.
            A: optional [in_features, rank] tensor to use as the down-projection. It is stored
                as given, so it is only registered as a parameter if it already is an nn.Parameter.
                Otherwise A is a trainable parameter initialized from N(0, 1/in_features).
        B is a trainable [rank, out_features] parameter initialized to zeros. At initialization
        the layer therefore reproduces base_linear exactly.
        '''
        super(LoRALinear, self).__init__()
        # r == 1 is a valid LoRA rank; the previous `r > 1` check rejected it.
        assert isinstance(base_linear, nn.Linear) and r >= 1
        self.base_linear = base_linear
        # freeze base; make sure only lora weights are trainable
        self.base_linear.weight.requires_grad_(False)
        if self.base_linear.bias is not None:
            self.base_linear.bias.requires_grad_(False)

        rank = min(r, base_linear.in_features, base_linear.out_features)
        if A is not None:
            # use specified A matrix
            assert A.shape == torch.Size((base_linear.in_features, rank))
            self.A = A
        else:
            self.A = nn.Parameter(torch.empty((base_linear.in_features, rank)))
            nn.init.normal_(self.A, std=1 / math.sqrt(base_linear.in_features))  # following pg4 of https://arxiv.org/pdf/2406.08447
        self.B = nn.Parameter(torch.zeros(rank, base_linear.out_features))
        # lora scale factor
        self.scaling = alpha / r

    def forward(self, x):
        # (x @ A) @ B rather than torch.linalg.multi_dot: multi_dot only accepts a 1D/2D first
        # operand, whereas nn.Linear (and so this wrapper) accepts arbitrary leading batch dims.
        return self.base_linear(x) + (x @ self.A @ self.B) * self.scaling
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def extract_linear_layers(module):
    """Return every nn.Linear reachable from `module` (itself included), in module.modules() order."""
    is_linear = lambda submodule: isinstance(submodule, nn.Linear)
    return list(filter(is_linear, module.modules()))
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import numpy as np
|
|
4
|
+
import time
|
|
5
|
+
import torch
|
|
6
|
+
import copy
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from .torch_modules import LoRA_MLP, extract_linear_layers
|
|
9
|
+
from .util import get_device
|
|
10
|
+
from .data_samplers import DataSampler
|
|
11
|
+
from typing import Callable
|
|
12
|
+
|
|
13
|
+
# Module-level side effect: seed the torch, NumPy and Python RNGs with 0 on import.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)
del _seed_fn
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def train_lora_regression(
    base_nf: nn.Sequential,
    target_sampler: DataSampler,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    lora_rank: int,
    learning_rate=5e-3,
    batch_size=2**18,
    max_n_steps=30000,
    lr_scheduler_warmup_steps=7000,
    log_interval=100,
    convergence_patience=15,
    save_dir="",
):
    """
    Trains LoRAs for every linear layer of the given base neural field. Returns a new neural field with a LoRA applied and (optionally) the weights of each LoRA

    Args:
    - base_nf: MLP implemented as a nn.Sequential (containing nn.Linear layers, activation functions, optional input positional encoding)
    - target_sampler: a DataSampler that must be compatible with base_nf (i.e. input/target dimensions must match)
    - loss_fn: callable function (output,target) |-> loss scalar
    - lora_rank: Desired maximum rank of each LoRA
    - learning_rate: for ADAM optimizer
    - batch_size: num samples per step
    - max_n_steps: max number of training steps
    - lr_scheduler_warmup_steps: number of warmup steps for the learning rate scheduler. No warmup if 0.
    - log_interval: initial number of steps between logs (multiplied by 10 after each log until it reaches 1000)
    - convergence_patience: number of log_intervals of no improvement after which training is terminated early
    - save_dir: if provided, then weights and output reconstructions will be saved there

    Returns:
    - lora_weights: LoRA weight tensors, one per linear layer of the base model. Returns the iterate with the lowest loss.
    - lora_nf: nn.Sequential neural field with LoRA applied to every linear layer. Returns the iterate with the lowest loss.

    Side effects:
    - The top-level nn.Linear layers of base_nf are frozen in place (requires_grad=False),
      because the LoRA wrapper reuses them rather than copying them.
    - If save_dir is set, it receives periodic reconstructions, the checkpoint `lora_nf_best.pt`
      (whose 'loss' is the best logged loss) and the best model's reconstruction.
    """
    assert isinstance(base_nf, nn.Sequential)
    # make sure that base_nf and target_sampler are compatible
    device = get_device(base_nf)
    linear_layers = extract_linear_layers(base_nf)
    base_output_dim = linear_layers[-1].out_features
    _, dummy_targets = target_sampler.sample_batch(batch_size, device)
    assert dummy_targets.shape[-1] == base_output_dim, "base_nf and target_sampler are incompatible: output size mismatch"

    # set up lora neural field
    lora_nf = LoRA_MLP(base_nf, lora_rank)
    lora_nf.to(device)
    print(f"training LoRA on {device}")
    trainable = [p for p in lora_nf.parameters() if p.requires_grad]
    trainable_params = sum(p.numel() for p in trainable)
    print(f"# trainable LoRA parameters: {trainable_params}")

    # only hand the optimizer parameters that can receive gradients (the frozen base weights are excluded)
    optimizer = torch.optim.Adam(trainable, lr=learning_rate)
    # use learning rate scheduler if specified
    scheduler = None
    if lr_scheduler_warmup_steps > 0:
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=lr_scheduler_warmup_steps)
    # for logging
    best_loss = float('inf')
    best_lora_nf = copy.deepcopy(lora_nf)
    prev_time = time.perf_counter()
    periods_no_improve = 0  # for early stopping
    i = -1  # last executed step; stays -1 when max_n_steps == 0 (previously a NameError at save time)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for i in range(max_n_steps):
        batch_inputs, batch_targets = target_sampler.sample_batch(batch_size, device)  # [B,in], [B,out]
        output = lora_nf(batch_inputs.to(device))
        loss = loss_fn(output, batch_targets.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        if i % log_interval == 0:  # reconstruct output
            curr_loss = loss.item()
            elapsed_time = time.perf_counter() - prev_time
            print(f"Step#{i}: loss={curr_loss:.7f} time={int(elapsed_time)}[s]")
            if save_dir:
                with torch.no_grad():
                    target_sampler.save_model_output(lora_nf, save_path=f"{save_dir}/lora_nf_step_{i:05d}")

            if curr_loss < best_loss:
                print(f"\tdecreased {best_loss:.6f}-->{curr_loss:.6f}")
                best_loss = curr_loss
                periods_no_improve = 0  # reset
                best_lora_nf = copy.deepcopy(lora_nf)
            else:
                # early stopping if no improvement for several epochs
                periods_no_improve += 1
                if periods_no_improve >= convergence_patience:
                    print(f"Early stopping at step {i} with loss {curr_loss}")
                    break
            prev_time = time.perf_counter()
            if i > 0 and log_interval < 1000:
                log_interval *= 10
    if save_dir:
        torch.save({
            'step': i,
            'model_state_dict': best_lora_nf.state_dict(),
            # loss of the saved (best) model, not of the last logged step
            'loss': best_loss,
        }, f"{save_dir}/lora_nf_best.pt")
        # no_grad: avoid building an autograd graph for the full-resolution reconstruction
        with torch.no_grad():
            target_sampler.save_model_output(best_lora_nf, save_path=f"{save_dir}/lora_nf_best")

    return best_lora_nf.get_lora_weights(), best_lora_nf.as_sequential()
|
|
120
|
+
|
|
121
|
+
def train_base_model(
    base_nf: nn.Sequential,
    data_sampler: DataSampler,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    learning_rate=1e-4,
    batch_size=2**18,
    max_n_steps=100000,
    lr_scheduler_warmup_steps=0,
    log_interval=100,
    convergence_patience=15,
    save_dir="",
):
    """
    Trains the given base neural field to regress samples from data_sampler.

    Args:
    - base_nf: neural field module whose parameters are optimized in place with Adam
    - data_sampler: DataSampler providing (input, target) batches
    - loss_fn: callable function (output, target) |-> loss scalar
    - learning_rate, batch_size, max_n_steps, lr_scheduler_warmup_steps (0 = no warmup)
    - log_interval: initial number of steps between logs (multiplied by 10 after each log until it reaches 1000)
    - convergence_patience: number of logs without improvement before stopping early
    - save_dir: if provided, periodic reconstructions, the best model's reconstruction and
      `base_nf_best.pt` (whose 'loss' is the best logged loss) are saved there

    Returns a copy of base_nf at the best training iteration
    """
    device = get_device(base_nf)
    trainable_params = sum(p.numel() for p in base_nf.parameters() if p.requires_grad)
    print(f"training base model on {device}")
    print(f"# trainable base model parameters: {trainable_params}")

    optimizer = torch.optim.Adam(base_nf.parameters(), lr=learning_rate)  # optimize all base_nf parameters
    # use learning rate scheduler if specified
    scheduler = None
    if lr_scheduler_warmup_steps > 0:
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=lr_scheduler_warmup_steps)
    best_loss = float('inf')
    best_model = copy.deepcopy(base_nf)
    prev_time = time.perf_counter()
    periods_no_improve = 0  # for early stopping
    i = -1  # last executed step; stays -1 when max_n_steps == 0 (previously a NameError at save time)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for i in range(max_n_steps):
        batch_inputs, batch_targets = data_sampler.sample_batch(batch_size, device)  # [B,in], [B,out]
        output = base_nf(batch_inputs.to(device))
        loss = loss_fn(output, batch_targets.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        if i % log_interval == 0:  # reconstruct output
            curr_loss = loss.item()
            elapsed_time = time.perf_counter() - prev_time
            print(f"Step#{i}: loss={curr_loss:.7f} time={int(elapsed_time)}[s]")
            if save_dir:
                with torch.no_grad():
                    data_sampler.save_model_output(base_nf, save_path=f"{save_dir}/base_nf_step_{i:05d}")

            if curr_loss < best_loss:
                print(f"\tdecreased {best_loss:.6f}-->{curr_loss:.6f}")
                best_loss = curr_loss
                periods_no_improve = 0  # reset
                best_model = copy.deepcopy(base_nf)
            else:
                # early stopping if no improvement for several epochs
                periods_no_improve += 1
                if periods_no_improve >= convergence_patience:
                    print(f"Early stopping at step {i} with loss {curr_loss}")
                    break
            prev_time = time.perf_counter()
            if i > 0 and log_interval < 1000:
                log_interval *= 10
    if save_dir:
        # no_grad: avoid building an autograd graph for the full-resolution reconstruction
        with torch.no_grad():
            data_sampler.save_model_output(best_model, save_path=f"{save_dir}/base_nf_best")
        torch.save({
            'step': i,
            'model_state_dict': best_model.state_dict(),
            # loss of the saved (best) model, not of the last logged step
            'loss': best_loss,
        }, f"{save_dir}/base_nf_best.pt")

    return best_model
|
|
198
|
+
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import mcubes
|
|
3
|
+
import numpy as np
|
|
4
|
+
import packaging
|
|
5
|
+
import torch
|
|
6
|
+
import trimesh
|
|
7
|
+
|
|
8
|
+
def get_device(module):
    """
    Return the device of the first parameter of `module`, falling back to its first buffer.

    Raises:
        ValueError: if the module has neither parameters nor buffers, so no device can be
            inferred. Previously a bare StopIteration escaped here.
    """
    for param in module.parameters():
        return param.device
    for buffer in module.buffers():
        return buffer.device
    raise ValueError(f"cannot infer the device of {type(module).__name__}: it has no parameters or buffers")
|
|
13
|
+
|
|
14
|
+
def check_shape_equality(*images):  # borrowed from scikit-image
    """Raise ValueError unless every given image has the same shape as the first one."""
    reference_shape = images[0].shape
    for other in images[1:]:
        if other.shape != reference_shape:
            raise ValueError('Input images must have the same dimensions.')
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
#################
|
|
23
|
+
## borrowed from scikit-image for psnr metric
|
|
24
|
+
dtype_range = {
|
|
25
|
+
bool: (False, True),
|
|
26
|
+
np.bool_: (False, True),
|
|
27
|
+
float: (-1, 1),
|
|
28
|
+
np.float16: (-1, 1),
|
|
29
|
+
np.float32: (-1, 1),
|
|
30
|
+
np.float64: (-1, 1),
|
|
31
|
+
}
|
|
32
|
+
new_float_type = {
|
|
33
|
+
# preserved types
|
|
34
|
+
np.float32().dtype.char: np.float32,
|
|
35
|
+
np.float64().dtype.char: np.float64,
|
|
36
|
+
np.complex64().dtype.char: np.complex64,
|
|
37
|
+
np.complex128().dtype.char: np.complex128,
|
|
38
|
+
# altered types
|
|
39
|
+
np.float16().dtype.char: np.float32,
|
|
40
|
+
'g': np.float64, # np.float128 ; doesn't exist on windows
|
|
41
|
+
'G': np.complex128, # np.complex256 ; doesn't exist on windows
|
|
42
|
+
}
|
|
43
|
+
def _supported_float_type(input_dtype, allow_complex=False):
|
|
44
|
+
"""Return an appropriate floating-point dtype for a given dtype.
|
|
45
|
+
|
|
46
|
+
float32, float64, complex64, complex128 are preserved.
|
|
47
|
+
float16 is promoted to float32.
|
|
48
|
+
complex256 is demoted to complex128.
|
|
49
|
+
Other types are cast to float64.
|
|
50
|
+
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
53
|
+
input_dtype : np.dtype or tuple of np.dtype
|
|
54
|
+
The input dtype. If a tuple of multiple dtypes is provided, each
|
|
55
|
+
dtype is first converted to a supported floating point type and the
|
|
56
|
+
final dtype is then determined by applying `np.result_type` on the
|
|
57
|
+
sequence of supported floating point types.
|
|
58
|
+
allow_complex : bool, optional
|
|
59
|
+
If False, raise a ValueError on complex-valued inputs.
|
|
60
|
+
|
|
61
|
+
Returns
|
|
62
|
+
-------
|
|
63
|
+
float_type : dtype
|
|
64
|
+
Floating-point dtype for the image.
|
|
65
|
+
"""
|
|
66
|
+
if isinstance(input_dtype, tuple):
|
|
67
|
+
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
|
|
68
|
+
input_dtype = np.dtype(input_dtype)
|
|
69
|
+
if not allow_complex and input_dtype.kind == 'c':
|
|
70
|
+
raise ValueError("complex valued input is not supported")
|
|
71
|
+
return new_float_type.get(input_dtype.char, np.float64)
|
|
72
|
+
def _as_floats(image0, image1):
|
|
73
|
+
"""
|
|
74
|
+
Promote im1, im2 to nearest appropriate floating point precision.
|
|
75
|
+
"""
|
|
76
|
+
float_type = _supported_float_type((image0.dtype, image1.dtype))
|
|
77
|
+
image0 = np.asarray(image0, dtype=float_type)
|
|
78
|
+
image1 = np.asarray(image1, dtype=float_type)
|
|
79
|
+
return image0, image1
|
|
80
|
+
def mean_squared_error(image0, image1):
|
|
81
|
+
"""
|
|
82
|
+
Compute the mean-squared error between two images.
|
|
83
|
+
|
|
84
|
+
Parameters
|
|
85
|
+
----------
|
|
86
|
+
image0, image1 : ndarray
|
|
87
|
+
Images. Any dimensionality, must have same shape.
|
|
88
|
+
|
|
89
|
+
Returns
|
|
90
|
+
-------
|
|
91
|
+
mse : float
|
|
92
|
+
The mean-squared error (MSE) metric.
|
|
93
|
+
|
|
94
|
+
Notes
|
|
95
|
+
-----
|
|
96
|
+
.. versionchanged:: 0.16
|
|
97
|
+
This function was renamed from ``skimage.measure.compare_mse`` to
|
|
98
|
+
``skimage.metrics.mean_squared_error``.
|
|
99
|
+
|
|
100
|
+
"""
|
|
101
|
+
check_shape_equality(image0, image1)
|
|
102
|
+
image0, image1 = _as_floats(image0, image1)
|
|
103
|
+
return np.mean((image0 - image1) ** 2, dtype=np.float64)
|
|
104
|
+
|
|
105
|
+
def peak_signal_noise_ratio(image_true, image_test, *, data_range=None):
|
|
106
|
+
"""
|
|
107
|
+
Compute the peak signal to noise ratio (PSNR) for an image.
|
|
108
|
+
|
|
109
|
+
Parameters
|
|
110
|
+
----------
|
|
111
|
+
image_true : ndarray
|
|
112
|
+
Ground-truth image, same shape as im_test.
|
|
113
|
+
image_test : ndarray
|
|
114
|
+
Test image.
|
|
115
|
+
data_range : int, optional
|
|
116
|
+
The data range of the input image (distance between minimum and
|
|
117
|
+
maximum possible values). By default, this is estimated from the image
|
|
118
|
+
data-type.
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
psnr : float
|
|
123
|
+
The PSNR metric.
|
|
124
|
+
|
|
125
|
+
Notes
|
|
126
|
+
-----
|
|
127
|
+
.. versionchanged:: 0.16
|
|
128
|
+
This function was renamed from ``skimage.measure.compare_psnr`` to
|
|
129
|
+
``skimage.metrics.peak_signal_noise_ratio``.
|
|
130
|
+
|
|
131
|
+
References
|
|
132
|
+
----------
|
|
133
|
+
.. [1] https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
|
|
134
|
+
|
|
135
|
+
"""
|
|
136
|
+
check_shape_equality(image_true, image_test)
|
|
137
|
+
|
|
138
|
+
if data_range is None:
|
|
139
|
+
if image_true.dtype != image_test.dtype:
|
|
140
|
+
print(
|
|
141
|
+
"Inputs have mismatched dtype. Setting data_range based on "
|
|
142
|
+
"image_true."
|
|
143
|
+
)
|
|
144
|
+
dmin, dmax = dtype_range[image_true.dtype.type]
|
|
145
|
+
true_min, true_max = np.min(image_true), np.max(image_true)
|
|
146
|
+
if true_max > dmax or true_min < dmin:
|
|
147
|
+
raise ValueError(
|
|
148
|
+
"image_true has intensity values outside the range expected "
|
|
149
|
+
"for its data type. Please manually specify the data_range."
|
|
150
|
+
)
|
|
151
|
+
if true_min >= 0:
|
|
152
|
+
# most common case (255 for uint8, 1 for float)
|
|
153
|
+
data_range = dmax
|
|
154
|
+
else:
|
|
155
|
+
data_range = dmax - dmin
|
|
156
|
+
|
|
157
|
+
image_true, image_test = _as_floats(image_true, image_test)
|
|
158
|
+
|
|
159
|
+
err = mean_squared_error(image_true, image_test)
|
|
160
|
+
data_range = float(data_range) # prevent overflow for small integer types
|
|
161
|
+
return 10 * np.log10((data_range**2) / err)
|
|
162
|
+
#################
|
|
163
|
+
#################
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def measure_gpu_memory(device):
|
|
167
|
+
if torch.cuda.is_available():
|
|
168
|
+
allocated_memory = torch.cuda.memory_allocated(device)
|
|
169
|
+
max_allocated_memory = torch.cuda.max_memory_allocated(device)
|
|
170
|
+
print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
|
|
171
|
+
# print(f"Max allocated memory: {max_allocated_memory / 1024**2:.2f} MB")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def get_model_checkpoint_size_mb(checkpoint_path, return_param_count=False):
|
|
175
|
+
'''
|
|
176
|
+
checkpoint_path: full path containing file extension.
|
|
177
|
+
'''
|
|
178
|
+
checkpoint = torch.load(checkpoint_path, weights_only=True)["model_state_dict"]
|
|
179
|
+
size_model = 0
|
|
180
|
+
# for k,v in checkpoint.items():
|
|
181
|
+
# print(k, v.shape)
|
|
182
|
+
param_count = 0
|
|
183
|
+
for param in checkpoint.values():
|
|
184
|
+
# if param.is_floating_point():
|
|
185
|
+
size_model += param.nelement() * param.element_size()
|
|
186
|
+
param_count += param.nelement()
|
|
187
|
+
# else:
|
|
188
|
+
# size_model += param.numel() * torch.iinfo(param.dtype).bits
|
|
189
|
+
print(f"\tmodel size: {(size_model / 1e6):.3f} MB")
|
|
190
|
+
print(f"\tnum params: {param_count}")
|
|
191
|
+
if return_param_count:
|
|
192
|
+
return size_model / 1e6, param_count
|
|
193
|
+
return size_model / 1e6
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def mean_relative_l2(pred, target, eps=0.01):
|
|
197
|
+
loss = (pred - target.to(pred.dtype))**2 / (pred.detach()**2 + eps)
|
|
198
|
+
return loss.mean()
|
|
199
|
+
|
|
200
|
+
def mape_loss(pred, target, reduction='mean'):
|
|
201
|
+
# pred, target: [B, 1], torch tenspr
|
|
202
|
+
difference = (pred - target).abs()
|
|
203
|
+
scale = 1 / (target.abs() + 1e-2)
|
|
204
|
+
loss = difference * scale
|
|
205
|
+
|
|
206
|
+
if reduction == 'mean':
|
|
207
|
+
loss = loss.mean()
|
|
208
|
+
|
|
209
|
+
return loss
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
### marching cubes helpers from torch-ngp ###
|
|
213
|
+
def custom_meshgrid(*args):
|
|
214
|
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
|
215
|
+
if packaging.version.parse(torch.__version__) < packaging.version.parse('1.10'):
|
|
216
|
+
return torch.meshgrid(*args)
|
|
217
|
+
else:
|
|
218
|
+
return torch.meshgrid(*args, indexing='ij')
|
|
219
|
+
|
|
220
|
+
def extract_fields(bound_min, bound_max, resolution, query_func, device=torch.device("cpu")):
|
|
221
|
+
N = 64
|
|
222
|
+
X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N)
|
|
223
|
+
Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N)
|
|
224
|
+
Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N)
|
|
225
|
+
|
|
226
|
+
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
|
227
|
+
with torch.no_grad():
|
|
228
|
+
for xi, xs in enumerate(X):
|
|
229
|
+
for yi, ys in enumerate(Y):
|
|
230
|
+
for zi, zs in enumerate(Z):
|
|
231
|
+
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
|
232
|
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3]
|
|
233
|
+
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [N, 1] --> [x, y, z]
|
|
234
|
+
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
|
235
|
+
return u
|
|
236
|
+
|
|
237
|
+
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
|
238
|
+
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
|
239
|
+
vertices, triangles = mcubes.marching_cubes(u, threshold)
|
|
240
|
+
|
|
241
|
+
b_max_np = bound_max.detach().cpu().numpy()
|
|
242
|
+
b_min_np = bound_min.detach().cpu().numpy()
|
|
243
|
+
|
|
244
|
+
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
|
245
|
+
return vertices, triangles
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def save_mesh(save_path, sdf_model, device, resolution=256):
|
|
250
|
+
print(f"==> Saving mesh to {save_path}")
|
|
251
|
+
|
|
252
|
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
253
|
+
|
|
254
|
+
def query_func(pts):
|
|
255
|
+
pts = pts.to(device)
|
|
256
|
+
with torch.no_grad():
|
|
257
|
+
# with torch.cuda.amp.autocast(enabled=False):
|
|
258
|
+
sdfs = sdf_model(pts)
|
|
259
|
+
return sdfs
|
|
260
|
+
|
|
261
|
+
bounds_min = torch.FloatTensor([-1, -1, -1])
|
|
262
|
+
bounds_max = torch.FloatTensor([1, 1, 1])
|
|
263
|
+
|
|
264
|
+
vertices, triangles = extract_geometry(bounds_min, bounds_max, resolution=resolution, threshold=0, query_func=query_func)
|
|
265
|
+
|
|
266
|
+
mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
|
|
267
|
+
mesh.export(save_path)
|
|
268
|
+
|
|
269
|
+
print(f"==> Finished saving mesh.")
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: lora-nf
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Low-Rank Adaptation of Neural Fields
|
|
5
|
+
Home-page: https://github.com/dinhanhtruong/LoRA-NF
|
|
6
|
+
Author: Anh Truong
|
|
7
|
+
License: MIT License
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
9
|
+
Requires-Dist: torch
|
|
10
|
+
Requires-Dist: numpy
|
|
11
|
+
Requires-Dist: imageio
|
|
12
|
+
Requires-Dist: trimesh
|
|
13
|
+
Requires-Dist: pysdf
|
|
14
|
+
Requires-Dist: PyMCubes
|
|
15
|
+
Requires-Dist: packaging
|
|
16
|
+
Dynamic: author
|
|
17
|
+
Dynamic: classifier
|
|
18
|
+
Dynamic: home-page
|
|
19
|
+
Dynamic: license
|
|
20
|
+
Dynamic: requires-dist
|
|
21
|
+
Dynamic: summary
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
setup.py
|
|
3
|
+
lora_nf/__init__.py
|
|
4
|
+
lora_nf/common.py
|
|
5
|
+
lora_nf/data_samplers.py
|
|
6
|
+
lora_nf/torch_modules.py
|
|
7
|
+
lora_nf/train_lora.py
|
|
8
|
+
lora_nf/util.py
|
|
9
|
+
lora_nf.egg-info/PKG-INFO
|
|
10
|
+
lora_nf.egg-info/SOURCES.txt
|
|
11
|
+
lora_nf.egg-info/dependency_links.txt
|
|
12
|
+
lora_nf.egg-info/requires.txt
|
|
13
|
+
lora_nf.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
lora_nf
|
lora_nf-0.1.1/setup.cfg
ADDED
lora_nf-0.1.1/setup.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from setuptools import setup
|
|
2
|
+
|
|
3
|
+
setup(
|
|
4
|
+
name='lora-nf',
|
|
5
|
+
version='0.1.1',
|
|
6
|
+
description='Low-Rank Adaptation of Neural Fields',
|
|
7
|
+
url='https://github.com/dinhanhtruong/LoRA-NF',
|
|
8
|
+
author='Anh Truong',
|
|
9
|
+
license='MIT License',
|
|
10
|
+
packages=['lora_nf'],
|
|
11
|
+
install_requires=['torch',
|
|
12
|
+
'numpy',
|
|
13
|
+
'imageio',
|
|
14
|
+
'trimesh',
|
|
15
|
+
'pysdf',
|
|
16
|
+
'PyMCubes',
|
|
17
|
+
'packaging'],
|
|
18
|
+
classifiers=[
|
|
19
|
+
'Programming Language :: Python :: 3.12',
|
|
20
|
+
],
|
|
21
|
+
)
|