gaussian-splatting 1.18.0__cp312-cp312-win_amd64.whl → 1.19.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,9 +8,17 @@ from .trainer.extensions import ScaleRegularizeTrainerWrapper
8
8
 
9
9
  def prepare_dataset(source: str, device: str, trainable_camera: bool = False, load_camera: str = None, load_mask=True, load_depth=True) -> CameraDataset:
10
10
  if trainable_camera:
11
- dataset = (TrainableCameraDataset.from_json(load_camera, load_mask=load_mask, load_depth=load_depth) if load_camera else ColmapTrainableCameraDataset(source, load_mask=load_mask, load_depth=load_depth)).to(device)
11
+ dataset = (
12
+ TrainableCameraDataset.from_json(load_camera, load_mask=load_mask, load_depth=load_depth)
13
+ if load_camera else
14
+ ColmapTrainableCameraDataset(source, load_mask=load_mask, load_depth=load_depth)
15
+ ).to(device)
12
16
  else:
13
- dataset = (FixedTrainableCameraDataset(load_camera, load_mask=load_mask, load_depth=load_depth) if load_camera else ColmapCameraDataset(source, load_mask=load_mask, load_depth=load_depth)).to(device)
17
+ dataset = (
18
+ FixedTrainableCameraDataset(load_camera, load_mask=load_mask, load_depth=load_depth)
19
+ if load_camera else
20
+ ColmapCameraDataset(source, load_mask=load_mask, load_depth=load_depth)
21
+ ).to(device)
14
22
  return dataset
15
23
 
16
24
 
@@ -15,8 +15,8 @@ from gaussian_splatting.prepare import prepare_dataset, prepare_gaussians
15
15
  def prepare_rendering(
16
16
  sh_degree: int, source: str, device: str,
17
17
  trainable_camera: bool = False, load_ply: str = None, load_camera: str = None,
18
- load_depth=False) -> Tuple[CameraDataset, GaussianModel]:
19
- dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_mask=False, load_depth=load_depth)
18
+ load_depth=True) -> Tuple[CameraDataset, GaussianModel]:
19
+ dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_mask=True, load_depth=load_depth)
20
20
  gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
21
21
  return dataset, gaussians
22
22
 
@@ -15,7 +15,7 @@ from gaussian_splatting.prepare import basemodes, shliftmodes, prepare_dataset,
15
15
  def prepare_training(
16
16
  sh_degree: int, source: str, device: str, mode: str,
17
17
  trainable_camera: bool = False, load_ply: str = None, load_camera: str = None,
18
- load_mask=False, load_depth=False,
18
+ load_mask=True, load_depth=True,
19
19
  with_scale_reg=False, configs={}) -> Tuple[CameraDataset, GaussianModel, AbstractTrainer]:
20
20
  dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_mask=load_mask, load_depth=load_depth)
21
21
  gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
@@ -33,6 +33,10 @@ class AbstractTrainer(ABC):
33
33
  def schedulers(self) -> Dict[str, Callable[[int], float]]:
34
34
  raise ValueError("Schedulers is not set")
35
35
 
36
+ @abstractmethod
37
+ def preprocess(self, camera: Camera) -> Camera:
38
+ pass
39
+
36
40
  @abstractmethod
37
41
  def loss(self, out: dict, camera: Camera) -> torch.Tensor:
38
42
  pass
@@ -49,6 +53,7 @@ class AbstractTrainer(ABC):
49
53
 
50
54
  def step(self, camera: Camera):
51
55
  self.update_learning_rate()
56
+ camera = self.preprocess(camera)
52
57
  out = self.model(camera)
53
58
  loss = self.loss(out, camera)
54
59
  loss.backward()
@@ -100,6 +105,9 @@ class TrainerWrapper(AbstractTrainer):
100
105
  def schedulers(self) -> Dict[str, Callable[[int], float]]:
101
106
  return self.base_trainer.schedulers
102
107
 
108
+ def preprocess(self, camera: Camera) -> Camera:
109
+ return self.base_trainer.preprocess(camera)
110
+
103
111
  def loss(self, out: dict, camera: Camera) -> torch.Tensor:
104
112
  return self.base_trainer.loss(out, camera)
105
113
 
@@ -22,11 +22,23 @@ class BaseTrainer(AbstractTrainer):
22
22
  opacity_lr=0.025,
23
23
  scaling_lr=0.005,
24
24
  rotation_lr=0.001,
25
- mask_mode="ignore", # "ignore", "random", "bg_color"
25
+ mask_mode="none",
26
+ # "none"=do not use mask
27
+ # "ignore"=loss of the masked area will be set to 0
28
+ # "bg_color"=fill the masked area of ground truth with the bg_color for rendering
29
+ bg_color=None,
30
+ # None=do not change bg_color
31
+ # "random"=set bg_color to random color
32
+ # tuple(float, float, float)=set bg_color to the given color
26
33
  ):
27
34
  super().__init__()
28
35
  self.lambda_dssim = lambda_dssim
36
+ assert mask_mode in ["none", "ignore", "bg_color"], f"Unknown mask policy: {mask_mode}"
37
+ assert bg_color is None or bg_color == "random" or (
38
+ isinstance(bg_color, tuple) and len(bg_color) == 3 and all(isinstance(c, float) for c in bg_color)
39
+ ), f"bg_color must be 'random' or a RGB value tuple(float, float, float), got {bg_color}"
29
40
  self.mask_mode = mask_mode
41
+ self.bg_color = bg_color
30
42
  params = [
31
43
  {'params': [model._xyz], 'lr': position_lr_init * scene_extent, "name": "xyz"},
32
44
  {'params': [model._features_dc], 'lr': feature_lr, "name": "f_dc"},
@@ -69,21 +81,33 @@ class BaseTrainer(AbstractTrainer):
69
81
  def schedulers(self) -> Dict[str, Callable[[int], float]]:
70
82
  return self._schedulers
71
83
 
84
+ def preprocess(self, camera: Camera) -> Camera:
85
+ if self.bg_color == "random":
86
+ camera = camera._replace(bg_color=torch.rand_like(camera.bg_color))
87
+ elif isinstance(self.bg_color, tuple):
88
+ camera = camera._replace(bg_color=torch.tensor(self.bg_color, device=camera.bg_color.device, dtype=camera.bg_color.dtype))
89
+ elif self.bg_color is None:
90
+ pass
91
+ else:
92
+ raise ValueError(f"bg_color must be 'random' or a tuple(int, int, int), got {self.bg_color}")
93
+ return camera
94
+
72
95
  def loss(self, out: dict, camera: Camera) -> torch.Tensor:
73
96
  render = out["render"]
74
97
  gt = camera.ground_truth_image
75
98
  mask = camera.ground_truth_image_mask
76
- if mask is not None:
77
- match self.mask_mode:
78
- case "ignore":
79
- render = render * mask.unsqueeze(0)
80
- gt = gt * mask.unsqueeze(0)
81
- case "random":
82
- gt = gt * mask.unsqueeze(0) + (1 - mask.unsqueeze(0)) * torch.rand_like(gt)
83
- case "bg_color":
84
- gt = gt * mask.unsqueeze(0) + (1 - mask.unsqueeze(0)) * camera.bg_color.unsqueeze(-1).unsqueeze(-1)
85
- case _:
86
- raise ValueError(f"Unknown mask policy: {self.mask_mode}")
99
+ match self.mask_mode:
100
+ case "none":
101
+ pass
102
+ case "ignore":
103
+ assert mask is not None, "Mask is required for 'ignore' mask policy"
104
+ render = render * mask.unsqueeze(0)
105
+ gt = gt * mask.unsqueeze(0)
106
+ case "bg_color":
107
+ assert mask is not None, "Mask is required for 'bg_color' mask policy"
108
+ gt = gt * mask.unsqueeze(0) + (1 - mask.unsqueeze(0)) * camera.bg_color.unsqueeze(-1).unsqueeze(-1)
109
+ case _:
110
+ raise ValueError(f"Unknown mask policy: {self.mask_mode}")
87
111
  Ll1 = l1_loss(render, gt)
88
112
  ssim_value = ssim(render, gt)
89
113
  loss = (1.0 - self.lambda_dssim) * Ll1 + self.lambda_dssim * (1.0 - ssim_value)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gaussian_splatting
3
- Version: 1.18.0
3
+ Version: 1.19.0
4
4
  Summary: Refactored python training and inference code for 3D Gaussian Splatting
5
5
  Home-page: https://github.com/yindaheng98/gaussian-splatting
6
6
  Author: yindaheng98
@@ -2,9 +2,9 @@ gaussian_splatting/__init__.py,sha256=CiOZMcyPTAaKtEuMZUhEda_Ad4_RUhmIstB-A3iuOJ
2
2
  gaussian_splatting/camera.py,sha256=vo7mu6lyFpIhDqOAgNiJuPan8t_nDJn5cJkAYygLFcA,9243
3
3
  gaussian_splatting/camera_trainable.py,sha256=nI6hFFRV2ev7VwLlKUbzEdN9zUmngYZAANGLr1p1yBA,3841
4
4
  gaussian_splatting/gaussian_model.py,sha256=_Dy_dDa2prALhVgg428a-O8-8PODg3c_JPkOJJ8X4o8,13275
5
- gaussian_splatting/prepare.py,sha256=GWPRpufg5larcKGNwlRtN22xVs9fp3ptMu11rQeySX8,3141
6
- gaussian_splatting/render.py,sha256=sh67INAUhEy5lfkmuB6RBTqe08NF-8MMR7QXTy4Ogg8,5990
7
- gaussian_splatting/train.py,sha256=no945bMVVRylvkkllQIURDplSwm8EydvUhn-wD9Cn2k,5388
5
+ gaussian_splatting/prepare.py,sha256=rgwdDhPU-sS57ZgTLKwzognbS1hfR2JtjZqlC4FqPbI,3241
6
+ gaussian_splatting/render.py,sha256=MSvJiOpyIr5IgJeaB6xj2GDcK4Jb6mcj7CWVXb-ajIk,5988
7
+ gaussian_splatting/train.py,sha256=a3vHZf5sX5eVsxv394pmwTSon1mPevQVPGAWwYHqaY8,5386
8
8
  gaussian_splatting/dataset/__init__.py,sha256=-runuT-61P0YVpfV_WXqwUZM1oY0N012YH13Bt3rzSU,138
9
9
  gaussian_splatting/dataset/camera_trainable.py,sha256=Kd8v-_ZJ9dLIQ2QyVOXbmouYf5QjbgOgHNRHVpkgCms,5041
10
10
  gaussian_splatting/dataset/dataset.py,sha256=0tmIZ5P7kOEdABiEAXPznkRN91e5rcT5VsAzOLoOuEM,2392
@@ -12,12 +12,12 @@ gaussian_splatting/dataset/colmap/__init__.py,sha256=YEYT2k2WJSqrkkZq4KAJYS9UMgq
12
12
  gaussian_splatting/dataset/colmap/dataset.py,sha256=0UBQ6ynOqElHZSphJ-MSbYQqCwwYZaAXl1y9AY5YKuY,4720
13
13
  gaussian_splatting/dataset/colmap/params_init.py,sha256=6_6gZ0Wl4aZrps2PJ_U234sxW5D-vOTfwioVa1FWC-E,1802
14
14
  gaussian_splatting/dataset/colmap/read_write_model.py,sha256=TenI7ai5UV7Ksg2vAXvJWnYFwOOo1tlS_633RfCLuQU,23137
15
- gaussian_splatting/diff_gaussian_rasterization/_C.cp312-win_amd64.pyd,sha256=BMYDyte6huIWUJqhWk3ofZ6jPYLkYgc83R2UnS65s9w,1300992
15
+ gaussian_splatting/diff_gaussian_rasterization/_C.cp312-win_amd64.pyd,sha256=eMoDrEyUc6FJtjdn_GA6aO1-m9X7e3eYrrjRRNtA8T4,1300992
16
16
  gaussian_splatting/diff_gaussian_rasterization/__init__.py,sha256=a9D0IZiPx-Mk1795hSq54T-NYT4MtEN_MZrxeMhw0Eo,6705
17
- gaussian_splatting/simple_knn/_C.cp312-win_amd64.pyd,sha256=rBcv-SSPzFMB419uG8FB5mbKqz0h8TCMxOXxO65j1SI,1168896
17
+ gaussian_splatting/simple_knn/_C.cp312-win_amd64.pyd,sha256=mviHoj2YPeMyUf8t472edMWVWWOPGS-U6_M9IEIloaw,1168896
18
18
  gaussian_splatting/trainer/__init__.py,sha256=962fEY8A0spSQn5de_d_LkPOjA1PYKrLbuAkxwZo7mI,940
19
- gaussian_splatting/trainer/abc.py,sha256=kpYnJjLOhsyhE-V2J79EC9nih6MYBcXkmK9cHUA-3ao,4022
20
- gaussian_splatting/trainer/base.py,sha256=gO1x4m82xrZNl8NZVw2CWYqIvZJIMUWmBtPZQPeyxJ0,3370
19
+ gaussian_splatting/trainer/abc.py,sha256=_gcqmEobhSOdZnMyNb2oKS6cZJ-Mg3oYL4xJ5Y3_oic,4262
20
+ gaussian_splatting/trainer/base.py,sha256=GfPifKjelRoCOXLJP6_tk1jgFWTJloBaCyUiVCcE9u4,4724
21
21
  gaussian_splatting/trainer/camera_trainable.py,sha256=TBQXn2f578qeizPz6tgqFm-GRvttv9duuB1xx7_J9TQ,4567
22
22
  gaussian_splatting/trainer/combinations.py,sha256=7NX4fXdDOx8ri1_mgAaWNx-YVdo5XsqMlr9qy-Ll2MM,5329
23
23
  gaussian_splatting/trainer/depth.py,sha256=PxWBSNxzoQcRfCFI_yJnJMS6s8qFWn81CXK6O6ffXL0,7059
@@ -45,8 +45,8 @@ gaussian_splatting/utils/lpipsPyTorch/modules/__init__.py,sha256=47DEQpj8HBSa-_T
45
45
  gaussian_splatting/utils/lpipsPyTorch/modules/lpips.py,sha256=YScu0oXIEstCCjJVRItS_R_csUw70sBMFuP8Syl2UdI,1187
46
46
  gaussian_splatting/utils/lpipsPyTorch/modules/networks.py,sha256=kqIebq7dAhHypTXweFVEf_RDbN7_Zv7O3MlD-CfRvpg,2788
47
47
  gaussian_splatting/utils/lpipsPyTorch/modules/utils.py,sha256=TDcem3E3HqDNN2MT8qlOL_BKVHeO4HRE77JxF-kOWk8,915
48
- gaussian_splatting-1.18.0.dist-info/licenses/LICENSE.md,sha256=bMuRQKn0u485mx8JBBTJ5Simc-aWHaQsxmoB6jsg5oE,4752
49
- gaussian_splatting-1.18.0.dist-info/METADATA,sha256=z6Q1TP39OROC_mlfGLSBgZQUip5neZOFgqQKx1Gbj_c,17183
50
- gaussian_splatting-1.18.0.dist-info/WHEEL,sha256=8UP9x9puWI0P1V_d7K2oMTBqfeLNm21CTzZ_Ptr0NXU,101
51
- gaussian_splatting-1.18.0.dist-info/top_level.txt,sha256=uaYrPYXRHhpybgCnsoazTcdhpzZGnLT_vd5eoRzBWWI,19
52
- gaussian_splatting-1.18.0.dist-info/RECORD,,
48
+ gaussian_splatting-1.19.0.dist-info/licenses/LICENSE.md,sha256=bMuRQKn0u485mx8JBBTJ5Simc-aWHaQsxmoB6jsg5oE,4752
49
+ gaussian_splatting-1.19.0.dist-info/METADATA,sha256=L0gv8LTr-DW4WFGl_jIFl568SuvIqlhsp863yh5LSZI,17183
50
+ gaussian_splatting-1.19.0.dist-info/WHEEL,sha256=8UP9x9puWI0P1V_d7K2oMTBqfeLNm21CTzZ_Ptr0NXU,101
51
+ gaussian_splatting-1.19.0.dist-info/top_level.txt,sha256=uaYrPYXRHhpybgCnsoazTcdhpzZGnLT_vd5eoRzBWWI,19
52
+ gaussian_splatting-1.19.0.dist-info/RECORD,,