tensorneko 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensorneko/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
1
  import os.path
2
2
 
3
- from tensorneko_util import io
4
3
  from . import backend
5
4
  from . import callback
6
5
  from . import dataset
7
6
  from . import debug
8
7
  from . import evaluation
8
+ from . import io
9
9
  from . import layer
10
10
  from . import module
11
11
  from . import notebook
@@ -13,10 +13,13 @@ from . import optim
13
13
  from . import preprocess
14
14
  from . import util
15
15
  from . import visualization
16
+ from .io import read, write
16
17
  from .neko_model import NekoModel
17
18
  from .neko_module import NekoModule
18
19
  from .neko_trainer import NekoTrainer
19
20
 
21
+ __version__ = io.read.text(os.path.join(util.get_tensorneko_path(), "version.txt"))
22
+
20
23
  __all__ = [
21
24
  "callback",
22
25
  "dataset",
@@ -33,7 +36,9 @@ __all__ = [
33
36
  "debug",
34
37
  "NekoModel",
35
38
  "NekoTrainer",
36
- "NekoModule"
39
+ "NekoModule",
40
+ "read",
41
+ "write",
37
42
  ]
38
43
 
39
- __version__ = io.read.text(os.path.join(util.get_tensorneko_path(), "version.txt"))
44
+
@@ -6,16 +6,18 @@ from torch.optim import Adam
6
6
  from torchmetrics import Accuracy, F1Score, AUROC
7
7
 
8
8
  from ..neko_model import NekoModel
9
+ from ..util import Shape
9
10
 
10
11
 
11
12
  class BinaryClassifier(NekoModel):
12
13
 
13
- def __init__(self, name, model: Module, learning_rate: float = 1e-4, distributed: bool = False):
14
- super().__init__(name)
14
+ def __init__(self, name, model: Module, learning_rate: float = 1e-4, distributed: bool = False,
15
+ input_shape: Optional[Shape] = None
16
+ ):
17
+ super().__init__(name, input_shape=input_shape, distributed=distributed)
15
18
  self.save_hyperparameters()
16
19
  self.model = model
17
20
  self.learning_rate = learning_rate
18
- self.distributed = distributed
19
21
  self.loss_fn = BCEWithLogitsLoss()
20
22
  self.acc_fn = Accuracy(task="binary")
21
23
  self.f1_fn = F1Score(task="binary")
@@ -23,9 +25,9 @@ class BinaryClassifier(NekoModel):
23
25
 
24
26
  @classmethod
25
27
  def from_module(cls, model: Module, learning_rate: float = 1e-4, name: str = "binary_classifier",
26
- distributed: bool = False
28
+ distributed: bool = False, input_shape: Optional[Shape] = None
27
29
  ):
28
- return cls(name, model, learning_rate, distributed)
30
+ return cls(name, model, learning_rate, distributed, input_shape)
29
31
 
30
32
  def forward(self, x):
31
33
  return self.model(x)
@@ -38,7 +40,7 @@ class BinaryClassifier(NekoModel):
38
40
  acc = self.acc_fn(prob, y)
39
41
  f1 = self.f1_fn(prob, y)
40
42
  auc = self.auc_fn(prob, y)
41
- return {"loss": loss, "acc": acc, "f1": f1, "auc": auc}
43
+ return {"loss": loss, "metric/acc": acc, "metric/f1": f1, "metric/auc": auc}
42
44
 
43
45
  def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
44
46
  optimizer_idx: Optional[int] = None, hiddens: Optional[Tensor] = None
tensorneko/arch/gan.py CHANGED
@@ -140,7 +140,7 @@ class GAN(NekoModel, ABC):
140
140
 
141
141
  # calculate discriminator loss
142
142
  d_loss = self.d_loss_fn(self.discriminator(d_batch), d_label)
143
- return {"loss": d_loss, "d_loss": d_loss}
143
+ return {"loss": d_loss, "loss/d_loss": d_loss}
144
144
 
145
145
  def g_step(self, x: Tensor, z: Tensor) -> Dict[str, Tensor]:
146
146
  # forward generator
@@ -150,7 +150,7 @@ class GAN(NekoModel, ABC):
150
150
 
151
151
  # calculate generator loss
152
152
  g_loss = self.g_loss_fn(self.discriminator(fake_image), target_label)
153
- return {"loss": g_loss, "g_loss": g_loss}
153
+ return {"loss": g_loss, "loss/g_loss": g_loss}
154
154
 
155
155
  def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
156
156
  dataloader_idx: Optional[int] = None
tensorneko/arch/vqvae.py CHANGED
@@ -72,7 +72,7 @@ class VQVAE(NekoModel, ABC):
72
72
  x_hat, embedding_loss = self(x)
73
73
  rec_loss = self.rec_loss_fn(x_hat, x)
74
74
  loss = rec_loss + embedding_loss
75
- return {"loss": loss, "r_loss": rec_loss, "e_loss": embedding_loss}
75
+ return {"loss": loss, "loss/r_loss": rec_loss, "loss/e_loss": embedding_loss}
76
76
 
77
77
  def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
78
78
  dataloader_idx: Optional[int] = None
@@ -84,7 +84,7 @@ class VQVAE(NekoModel, ABC):
84
84
  x_hat, embedding_loss = self(x)
85
85
  rec_loss = self.rec_loss_fn(x_hat, x)
86
86
  loss = rec_loss + embedding_loss
87
- return {"loss": loss, "r_loss": rec_loss, "e_loss": embedding_loss}
87
+ return {"loss": loss, "loss/r_loss": rec_loss, "loss/e_loss": embedding_loss}
88
88
 
89
89
  def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None) -> Tensor:
90
90
  return self(batch)[0]
tensorneko/arch/wgan.py CHANGED
@@ -100,4 +100,4 @@ class WGAN(GAN, ABC):
100
100
  gp = torch.tensor(0., device=self.device)
101
101
  loss = d_loss
102
102
 
103
- return {"loss": loss, "d_loss": d_loss, "gp": gp}
103
+ return {"loss": loss, "loss/d_loss": d_loss, "loss/gp": gp}
@@ -1,8 +1,10 @@
1
- from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib
1
+ from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib, import_tqdm_auto, import_tqdm
2
2
 
3
3
  __all__ = [
4
4
  "parallel",
5
5
  "run_blocking",
6
6
  "VisualLib",
7
7
  "AudioLib",
8
+ "import_tqdm_auto",
9
+ "import_tqdm",
8
10
  ]
@@ -22,7 +22,7 @@ class EarlyStoppingLR(Callback):
22
22
  return
23
23
  all_lr = []
24
24
  for key, value in metrics.items():
25
- if re.match(r"opt\d+_lr\d+", key):
25
+ if re.match(r"learning_rate/opt\d+_lr\d+", key):
26
26
  all_lr.append(value)
27
27
 
28
28
  if len(all_lr) == 0:
@@ -5,7 +5,7 @@ class EpochNumLogger(Callback):
5
5
  """Log epoch number in each epoch start."""
6
6
 
7
7
  def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
8
- key = "epoch"
8
+ key = "epoch/epoch"
9
9
  value = trainer.current_epoch
10
10
  pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
11
11
  pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)
@@ -12,5 +12,5 @@ class EpochTimeLogger(Callback):
12
12
 
13
13
  def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
14
14
  elapsed_time = time() - self.start_time
15
- pl_module.logger.log_metrics({"epoch_time": elapsed_time}, step=trainer.global_step)
16
- pl_module.log("epoch_time", elapsed_time, logger=False, sync_dist=pl_module.distributed)
15
+ pl_module.logger.log_metrics({"epoch/epoch_time": elapsed_time}, step=trainer.global_step)
16
+ pl_module.log("epoch/epoch_time", elapsed_time, logger=False, sync_dist=pl_module.distributed)
@@ -29,14 +29,14 @@ class GpuStatsLogger(Callback):
29
29
  return
30
30
  for gpu in self.monitor_epoch.average_stats.gpus:
31
31
  logged_info = {
32
- f"gpu{gpu.index}_memory_used_epoch": gpu.memory_used / 1024,
33
- f"gpu{gpu.index}_memory_total_epoch": gpu.memory_total / 1024,
34
- f"gpu{gpu.index}_memory_util_epoch": gpu.memory_used / gpu.memory_total,
35
- f"gpu{gpu.index}_temperature_epoch": float(gpu.temperature),
36
- f"gpu{gpu.index}_utilization_epoch": gpu.utilization / 100,
37
- f"gpu{gpu.index}_power_draw_epoch": float(gpu.power_draw),
38
- f"gpu{gpu.index}_power_percentage_epoch": gpu.power_draw / gpu.power_limit,
39
- f"gpu{gpu.index}_fan_speed_epoch": float(gpu.fan_speed) if gpu.fan_speed is not None else 0.,
32
+ f"system/gpu{gpu.index}_memory_used_epoch": gpu.memory_used / 1024,
33
+ f"system/gpu{gpu.index}_memory_total_epoch": gpu.memory_total / 1024,
34
+ f"system/gpu{gpu.index}_memory_util_epoch": gpu.memory_used / gpu.memory_total,
35
+ f"system/gpu{gpu.index}_temperature_epoch": float(gpu.temperature),
36
+ f"system/gpu{gpu.index}_utilization_epoch": gpu.utilization / 100,
37
+ f"system/gpu{gpu.index}_power_draw_epoch": float(gpu.power_draw),
38
+ f"system/gpu{gpu.index}_power_percentage_epoch": gpu.power_draw / gpu.power_limit,
39
+ f"system/gpu{gpu.index}_fan_speed_epoch": float(gpu.fan_speed) if gpu.fan_speed is not None else 0.,
40
40
  }
41
41
  pl_module.logger.log_metrics(logged_info, step=trainer.global_step)
42
42
  pl_module.log_dict(logged_info, logger=False, sync_dist=pl_module.distributed)
@@ -55,14 +55,14 @@ class GpuStatsLogger(Callback):
55
55
  return
56
56
  for gpu in self.monitor_step.average_stats.gpus:
57
57
  logged_info = {
58
- f"gpu{gpu.index}_memory_used_step": gpu.memory_used / 1024,
59
- f"gpu{gpu.index}_memory_total_step": gpu.memory_total / 1024,
60
- f"gpu{gpu.index}_memory_util_step": gpu.memory_used / gpu.memory_total,
61
- f"gpu{gpu.index}_temperature_step": float(gpu.temperature),
62
- f"gpu{gpu.index}_utilization_step": gpu.utilization / 100,
63
- f"gpu{gpu.index}_power_draw_step": float(gpu.power_draw),
64
- f"gpu{gpu.index}_power_percentage_step": gpu.power_draw / gpu.power_limit,
65
- f"gpu{gpu.index}_fan_speed_step": float(gpu.fan_speed) if gpu.fan_speed is not None else 0.,
58
+ f"system/gpu{gpu.index}_memory_used_step": gpu.memory_used / 1024,
59
+ f"system/gpu{gpu.index}_memory_total_step": gpu.memory_total / 1024,
60
+ f"system/gpu{gpu.index}_memory_util_step": gpu.memory_used / gpu.memory_total,
61
+ f"system/gpu{gpu.index}_temperature_step": float(gpu.temperature),
62
+ f"system/gpu{gpu.index}_utilization_step": gpu.utilization / 100,
63
+ f"system/gpu{gpu.index}_power_draw_step": float(gpu.power_draw),
64
+ f"system/gpu{gpu.index}_power_percentage_step": gpu.power_draw / gpu.power_limit,
65
+ f"system/gpu{gpu.index}_fan_speed_step": float(gpu.fan_speed) if gpu.fan_speed is not None else 0.,
66
66
  }
67
67
  pl_module.logger.log_metrics(logged_info, step=trainer.global_step)
68
68
  pl_module.log_dict(logged_info, logger=False, sync_dist=pl_module.distributed)
@@ -7,7 +7,7 @@ class LrLogger(Callback):
7
7
  def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
8
8
  for i, optimizer in enumerate(trainer.optimizers):
9
9
  for j, params in enumerate(optimizer.param_groups):
10
- key = f"opt{i}_lr{j}"
10
+ key = f"learning_rate/opt{i}_lr{j}"
11
11
  value = params["lr"]
12
12
  pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
13
13
  pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)
@@ -23,8 +23,8 @@ class SystemStatsLogger(Callback):
23
23
  cpu_usage = self.psutil.cpu_percent()
24
24
  memory_usage = self.psutil.virtual_memory().percent
25
25
  logged_info = {
26
- "cpu_usage_epoch": cpu_usage,
27
- "memory_usage_epoch": memory_usage
26
+ "system/cpu_usage_epoch": cpu_usage,
27
+ "system/memory_usage_epoch": memory_usage
28
28
  }
29
29
  pl_module.logger.log_metrics(logged_info, step=trainer.global_step)
30
30
  pl_module.log_dict(logged_info, logger=False, sync_dist=pl_module.distributed)
@@ -37,8 +37,8 @@ class SystemStatsLogger(Callback):
37
37
  cpu_usage = self.psutil.cpu_percent()
38
38
  memory_usage = self.psutil.virtual_memory().percent
39
39
  logged_info = {
40
- "cpu_usage_step": cpu_usage,
41
- "memory_usage_step": memory_usage
40
+ "system/cpu_usage_step": cpu_usage,
41
+ "system/memory_usage_step": memory_usage
42
42
  }
43
43
  pl_module.logger.log_metrics(logged_info, step=trainer.global_step)
44
44
  pl_module.log_dict(logged_info, logger=False, sync_dist=pl_module.distributed)
@@ -12,12 +12,13 @@ from torchmetrics.image.fid import FrechetInceptionDistance
12
12
  from torchvision.transforms.functional import resize
13
13
 
14
14
  from tensorneko_util.backend import VisualLib
15
- from tensorneko_util.backend._tqdm import import_tqdm_auto
15
+ from tensorneko_util.backend.tqdm import import_tqdm_auto
16
16
 
17
17
  try:
18
18
  from typing import Literal
19
19
  TypeOption = Literal["video", "image"]
20
20
  except ImportError:
21
+ Literal = None
21
22
  TypeOption = str
22
23
 
23
24
 
tensorneko/io/reader.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
  from typing import Type
3
3
 
4
4
  from tensorneko_util.io.reader import Reader as BaseReader
5
+ from .weight import WeightReader
5
6
 
6
7
  try:
7
8
  from .mesh import MeshReader
@@ -15,6 +16,7 @@ class Reader(BaseReader):
15
16
 
16
17
  def __init__(self):
17
18
  super().__init__()
19
+ self.weight = WeightReader
18
20
  self._mesh = None
19
21
 
20
22
  @property
@@ -0,0 +1,2 @@
1
+ from .weight_reader import WeightReader
2
+ from .weight_writer import WeightWriter
@@ -0,0 +1,81 @@
1
+ from typing import OrderedDict
2
+
3
+ import torch
4
+
5
+ from ...util import Device
6
+
7
+
8
+ class WeightReader:
9
+ """WeightReader for read model weights (checkpoints, state_dict, etc)."""
10
+
11
+ @classmethod
12
+ def of_pt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
13
+ """
14
+ Reads PyTorch model weights from a `.pt` or `.pth` file.
15
+
16
+ Args:
17
+ path (``str``): The path of the `.pt` or `.pth` file.
18
+ map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
19
+
20
+ Returns:
21
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
22
+ """
23
+ return torch.load(path, map_location=map_location)
24
+
25
+ @classmethod
26
+ def of_ckpt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
27
+ """
28
+ Reads PyTorch model weights from a `.ckpt` file.
29
+
30
+ Args:
31
+ path (``str``): The path of the `.ckpt` file.
32
+ map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
33
+
34
+ Returns:
35
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
36
+ """
37
+ return torch.load(path, map_location=map_location)["state_dict"]
38
+
39
+ @classmethod
40
+ def of_safetensors(cls, path: str, map_location: str = "cpu") -> OrderedDict[str, torch.Tensor]:
41
+ """
42
+ Reads model weights from a `.safetensors` file.
43
+
44
+ Args:
45
+ path (``str``): The path of the `.safetensors` file.
46
+ map_location (``str``): The location to load the model weights. Default: "cpu"
47
+
48
+ Returns:
49
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
50
+ """
51
+ import safetensors
52
+ from collections import OrderedDict
53
+ tensors = OrderedDict()
54
+ with safetensors.safe_open(path, framework="pt", device=map_location) as f:
55
+ for key in f.keys():
56
+ tensors[key] = f.get_tensor(key)
57
+ return tensors
58
+
59
+ @classmethod
60
+ def of(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
61
+ """
62
+ Reads model weights from a file.
63
+
64
+ Args:
65
+ path (``str``): The path of the file.
66
+ map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
67
+
68
+ Returns:
69
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
70
+ """
71
+
72
+ if path.endswith(".pt") or path.endswith(".pth"):
73
+ return cls.of_pt(path, map_location)
74
+ elif path.endswith(".ckpt"):
75
+ return cls.of_ckpt(path, map_location)
76
+ elif path.endswith(".safetensors"):
77
+ if isinstance(map_location, torch.device):
78
+ map_location = str(map_location)
79
+ return cls.of_safetensors(path, map_location)
80
+ else:
81
+ raise ValueError("Unknown file type. Supported types: .pt, .pth, .ckpt, .safetensors")
@@ -0,0 +1,48 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+
5
+
6
+ class WeightWriter:
7
+ """WeightWriter for write model weights (checkpoints, state_dict, etc)."""
8
+
9
+ @classmethod
10
+ def to_pt(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
11
+ """
12
+ Writes PyTorch model weights to a `.pt` or `.pth` file.
13
+
14
+ Args:
15
+ path (``str``): The path of the `.pt` or `.pth` file.
16
+ weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
17
+ """
18
+ torch.save(weights, path)
19
+
20
+ @classmethod
21
+ def to_safetensors(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
22
+ """
23
+ Writes model weights to a `.safetensors` file.
24
+
25
+ Args:
26
+ path (``str``): The path of the `.safetensors` file.
27
+ weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
28
+ """
29
+ import safetensors.torch
30
+ safetensors.torch.save_file(weights, path)
31
+
32
+ @classmethod
33
+ def to(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
34
+ """
35
+ Writes model weights to a file.
36
+
37
+ Args:
38
+ path (``str``): The path of the file.
39
+ weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
40
+ """
41
+ file_type = path.split(".")[-1]
42
+
43
+ if file_type == "pt":
44
+ cls.to_pt(path, weights)
45
+ elif file_type == "safetensors":
46
+ cls.to_safetensors(path, weights)
47
+ else:
48
+ raise ValueError("Unknown file type. Supported types: .pt, .safetensors")
tensorneko/io/writer.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from typing import Type
4
4
 
5
5
  from tensorneko_util.io.writer import Writer as BaseWriter
6
+ from .weight import WeightWriter
6
7
 
7
8
  try:
8
9
  from .mesh import MeshWriter
@@ -16,6 +17,7 @@ class Writer(BaseWriter):
16
17
 
17
18
  def __init__(self):
18
19
  super().__init__()
20
+ self.weight = WeightWriter
19
21
  self._mesh = None
20
22
 
21
23
  @property
tensorneko/layer/noise.py CHANGED
@@ -22,7 +22,7 @@ class GaussianNoise(NekoModule):
22
22
  from https://discuss.pytorch.org/t/writing-a-simple-gaussian-noise-layer-in-pytorch/4694/3
23
23
  """
24
24
 
25
- def __init__(self, sigma=0.1, device: Union[Device, str] = "cuda"):
25
+ def __init__(self, sigma=0.1, device: Device = "cuda"):
26
26
  super().__init__()
27
27
  self.sigma = sigma
28
28
  self.noise = torch.tensor(0.).to(device)
tensorneko/neko_model.py CHANGED
@@ -170,6 +170,7 @@ class NekoModel(LightningModule, NekoModule):
170
170
  and self.logger is not None \
171
171
  and self.trainer.global_step % self.trainer.log_every_n_steps == 0:
172
172
  self.log_on_training_step_end(outputs)
173
+ super().on_train_batch_end(outputs, batch, batch_idx)
173
174
 
174
175
  def on_validation_batch_end(
175
176
  self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -177,18 +178,21 @@ class NekoModel(LightningModule, NekoModule):
177
178
  """Append the outputs of validation step to the list"""
178
179
  outputs = {k: v.detach() for k, v in outputs.items()}
179
180
  self.validation_step_outputs.append(outputs)
181
+ super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
180
182
 
181
183
  def on_train_epoch_end(self) -> None:
182
184
  """For each training epoch end, log the metrics"""
183
185
  if self.trainer.log_on_epoch and self.logger is not None:
184
186
  self.log_on_training_epoch_end(self.training_step_outputs)
185
187
  self.training_step_outputs.clear()
188
+ super().on_train_epoch_end()
186
189
 
187
190
  def on_validation_epoch_end(self) -> None:
188
191
  """For each validation epoch end, log the metrics"""
189
192
  if self.logger is not None:
190
193
  self.log_on_validation_epoch_end(self.validation_step_outputs)
191
194
  self.validation_step_outputs.clear()
195
+ super().on_validation_epoch_end()
192
196
 
193
197
  def log_on_training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
194
198
  """Log the training epoch outputs"""
@@ -197,7 +201,8 @@ class NekoModel(LightningModule, NekoModule):
197
201
  getter = summarize_dict_by(key, torch.mean)
198
202
  value = getter(outputs)
199
203
  history_item[key] = value
200
- self.logger.log_metrics({key: value}, step=self.trainer.global_step)
204
+ if self.logger is not None:
205
+ self.logger.log_metrics({key: value}, step=self.trainer.global_step)
201
206
  self.log(key, value, on_epoch=True, on_step=False, logger=False, sync_dist=self.distributed)
202
207
  self.history.append(history_item)
203
208
 
@@ -210,7 +215,8 @@ class NekoModel(LightningModule, NekoModule):
210
215
  getter = summarize_dict_by(key, torch.mean)
211
216
  value = getter(outputs)
212
217
  self.history[-1]["val_" + key] = value
213
- self.logger.log_metrics({key: value}, step=self.trainer.global_step)
218
+ if self.logger is not None:
219
+ self.logger.log_metrics({key: value}, step=self.trainer.global_step)
214
220
  self.log(key, value, on_epoch=True, on_step=False, logger=False, sync_dist=self.distributed)
215
221
  self.log(f"val_{key}", value, on_epoch=True, on_step=False, logger=False, prog_bar=True,
216
222
  sync_dist=self.distributed)
@@ -220,7 +226,8 @@ class NekoModel(LightningModule, NekoModule):
220
226
  history_item = {}
221
227
  for key, value in output.items():
222
228
  history_item[key] = value
223
- self.logger.log_metrics({key: value}, step=self.trainer.global_step)
229
+ if self.logger is not None:
230
+ self.logger.log_metrics({key: value}, step=self.trainer.global_step)
224
231
  self.log(key, value, on_epoch=False, on_step=True, logger=False, prog_bar=key == "loss",
225
232
  sync_dist=self.distributed)
226
233
  self.history.append(history_item)
@@ -230,6 +237,7 @@ class NekoModel(LightningModule, NekoModule):
230
237
  ) -> None:
231
238
  """For each test step end, log the metrics"""
232
239
  self.log_dict(outputs, sync_dist=self.distributed)
240
+ super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)
233
241
 
234
242
  def log_image(self, name: str, image: torch.Tensor) -> None:
235
243
  """Log an image to the logger"""
@@ -238,6 +246,15 @@ class NekoModel(LightningModule, NekoModule):
238
246
 
239
247
  self.logger.experiment.add_image(name, torch.clip(image, 0, 1), self.trainer.global_step)
240
248
 
249
+ def log_histogram(self, name: str, values: torch.Tensor) -> None:
250
+ """Log a histogram to the logger"""
251
+ if self.logger is None:
252
+ return
253
+
254
+ self.logger.experiment.add_histogram(name, values, self.trainer.global_step)
255
+
241
256
  def on_train_start(self) -> None:
242
257
  """Log the model graph to tensorboard when the input shape is set"""
243
- self.logger.log_graph(self, self.example_input_array)
258
+ if self.can_plot_graph and self.logger is not None:
259
+ self.logger.log_graph(self, self.example_input_array)
260
+ super().on_train_start()
tensorneko/util/type.py CHANGED
@@ -11,7 +11,7 @@ ModuleFactory = Union[Callable[[], Module], Callable[[int], Module]]
11
11
  """The module builder type of ``() -> torch.nn.Module | (int) -> torch.nn.Module``"""
12
12
 
13
13
 
14
- Device = device
14
+ Device = Union[str, device]
15
15
  """Device type of :class:`torch.device`"""
16
16
 
17
17
 
tensorneko/version.txt CHANGED
@@ -1 +1 @@
1
- 0.3.7
1
+ 0.3.9
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tensorneko
3
- Version: 0.3.7
3
+ Version: 0.3.9
4
4
  Summary: Tensor Neural Engine Kompanion. An util library based on PyTorch and PyTorch Lightning.
5
5
  Home-page: https://github.com/ControlNet/tensorneko
6
6
  Author: ControlNet
@@ -33,7 +33,7 @@ Requires-Dist: pillow >=8.1
33
33
  Requires-Dist: av >=8.0.3
34
34
  Requires-Dist: numpy >=1.20.1
35
35
  Requires-Dist: einops >=0.3.0
36
- Requires-Dist: tensorneko-util ==0.3.7
36
+ Requires-Dist: tensorneko-util ==0.3.9
37
37
  Requires-Dist: pysoundfile >=0.9.0 ; platform_system == "Windows"
38
38
 
39
39
  <h1 style="text-align: center">TensorNeko</h1>
@@ -1,24 +1,24 @@
1
- tensorneko/__init__.py,sha256=VPPK00Kduwi84QHnZKBZm8kBRdnPAji6f7J-adYAp_Y,770
2
- tensorneko/neko_model.py,sha256=bLgDMSLw4UDaZXY-7ZQ_f_lBxpsVROm-8FvGMnLvRaA,9769
1
+ tensorneko/__init__.py,sha256=lLVC4StQ5q9OFJTceNrJbj-CIFHFMRvFQyGYJTjuRM4,812
2
+ tensorneko/neko_model.py,sha256=QTbdOAg9ki0ix6mDL_Qu8Wmd5WJOoUFF3M1SXEp3KGc,10551
3
3
  tensorneko/neko_module.py,sha256=qELXvguSjWo_NvcRQibiFl0Qauzd9JWLSnT4dbGNS3Y,1473
4
4
  tensorneko/neko_trainer.py,sha256=JC8qoKSZ5ngz3grf3S0SjvIFVktDIP_GExth5aFfbGA,10074
5
- tensorneko/version.txt,sha256=K_gLRhLh6zul984g6rPv-L_EY9Hz2UtDN95TULjqVQ8,5
5
+ tensorneko/version.txt,sha256=JCuzYNF_l66JsKuMd57Z4WJC2x1-36y0IfHrRwznDqM,5
6
6
  tensorneko/arch/__init__.py,sha256=w4lTUeyBIZelrnSjlBFWUF0erzOmBFl9FqeWQuSOyKs,248
7
7
  tensorneko/arch/auto_encoder.py,sha256=j6PWWyaNYaYNtw_zZ9ikzhCASqe9viXR3JGBIXSK92Y,2137
8
- tensorneko/arch/binary_classifier.py,sha256=x3fQxQ0igyQ48mPB6vjPcDI38Q9WsbBJ--eOpldzgeI,2159
9
- tensorneko/arch/gan.py,sha256=ZAw6bNBXuTWmmC5rKpa7jgMisfpX-ti7gzyYkkh0_ls,7205
10
- tensorneko/arch/vqvae.py,sha256=02bHKJljBg6DTUfghxS3k-T5nOgYknhDU1Em0nirsj0,3730
11
- tensorneko/arch/wgan.py,sha256=k88x3ZtqqqKc0pIv6hiVqhpq-SitrVxrl7q8Etyqmpo,4712
12
- tensorneko/backend/__init__.py,sha256=hWEoQo3Bh7Fn7cfaJznA3k1sHWAUIJ9KAbDxOVjySxU,164
8
+ tensorneko/arch/binary_classifier.py,sha256=1MkEbReXKLdDksRG5Rsife40grJk08EVDcNKp54Xvb4,2316
9
+ tensorneko/arch/gan.py,sha256=bHh8s9UxHNSj9N6C6SllAaVJ61JhbY4ltwmQ-a1543w,7215
10
+ tensorneko/arch/vqvae.py,sha256=_HJoDVBonB4PHymaV9JaAecS-NJZ1jfrKAhoqcLzf0g,3750
11
+ tensorneko/arch/wgan.py,sha256=Oj1vZPXzhgTPHGGXGmbStgylskjjzulr6tYxj7E0zpQ,4722
12
+ tensorneko/backend/__init__.py,sha256=ppJhb1MC_WK6XLK2fX8x0z7Cn-8CkF97x67XBLoRCys,238
13
13
  tensorneko/callback/__init__.py,sha256=H1jOTsSYm9c4sxtcV9_uzumXZ95b4gbiftKqimNER7s,556
14
14
  tensorneko/callback/display_metrics_callback.py,sha256=qzhHcb68B7o9byfD1ZqEitSVkrwsSGOF-u59_Ip9dEg,318
15
- tensorneko/callback/earlystop_lr.py,sha256=mL27eghHXigohA3FZgcy7vObxXiaqHHoYkzw7dsgLf8,1181
16
- tensorneko/callback/epoch_num_logger.py,sha256=54BzG_Ez7hZHgx9Xvc4c42lMs9oppZivaEZIzIzaiYA,456
17
- tensorneko/callback/epoch_time_logger.py,sha256=PwOFvlkYk1mkGzpBF__FMzNMGPyKaZRyCQN2DeG0kMQ,645
18
- tensorneko/callback/gpu_stats_logger.py,sha256=vhX0uElDEDeoNrTl-hYdKhfMkE7pw4rdhmw6d8DEfeo,3420
19
- tensorneko/callback/lr_logger.py,sha256=4nC_teyCX3wmlELrJPq3TGrt2KssRpmgDRyep0h2J2c,605
15
+ tensorneko/callback/earlystop_lr.py,sha256=d0G3NHYi-tNNCdaB7Rt7vVceI8CXOGT4zBTJYNciyug,1195
16
+ tensorneko/callback/epoch_num_logger.py,sha256=efp8vANiWT3eRz-TqhbML5fc7d1h2jcmHSV6OJyFwFg,462
17
+ tensorneko/callback/epoch_time_logger.py,sha256=ktZSEinDBEstLbCGV5pjlQrkmSlVrZ0ymAECoxGX4rs,657
18
+ tensorneko/callback/gpu_stats_logger.py,sha256=qbkE47RKwRTCoI5YTu_9DFMMHDqmZl4n9slMa7cLQZo,3532
19
+ tensorneko/callback/lr_logger.py,sha256=28xmAZ_UOFg0wqg1VjJoifwjzvBsOs-g2nd4bog9ozw,619
20
20
  tensorneko/callback/nil_callback.py,sha256=-vKhOG3Ysv_ZToOdyYEkcZ8h0so9rBRY10f1OIoHeZs,131
21
- tensorneko/callback/system_stats_logger.py,sha256=RGiVz24N9P0I_-M9gmouGNjwvLwmFZd_ifSCsw-yZvc,1754
21
+ tensorneko/callback/system_stats_logger.py,sha256=dS4AOhEADU5cop6vSW5HnVew58jO9SdzKh4lkyssgcE,1782
22
22
  tensorneko/dataset/__init__.py,sha256=6980ci9Ce57HSyhzrKMJfDz31PCQxifVz1aSf63JEsA,247
23
23
  tensorneko/dataset/list_dataset.py,sha256=UhgTHapi7dQF4nEgnZH73oCLzluNAJU--N7Fhaa8P8s,1237
24
24
  tensorneko/dataset/nested_dataset.py,sha256=lMGW7ODfvBn-aRd7c7HftK6BQUZGCUVkm4XJh4iiPBg,2073
@@ -28,17 +28,20 @@ tensorneko/dataset/sampler/sequential_iter_sampler.py,sha256=DxBwSoWjYlq6kA6g-54
28
28
  tensorneko/debug/__init__.py,sha256=ZMfU3qquhMhl6EgPzM7Yuvvv0PWy3cR39UjPrrSmQcs,163
29
29
  tensorneko/evaluation/__init__.py,sha256=jW8dh1JRMpx3npjTp7wJLzz-IxFZTBh7F-Ztfoep9xs,296
30
30
  tensorneko/evaluation/enum.py,sha256=s3P8XAobku-as4in5vh6BanvVW5Ccwnff0t124lVFFg,137
31
- tensorneko/evaluation/fid.py,sha256=mDsgh7Ge7K8KrOLeWnSEVzzKfdCK0cI9TAWJJd5eqcQ,5550
31
+ tensorneko/evaluation/fid.py,sha256=fNuE1CEp2rPXbaZfI0E1CspluInzFlUdKc8XZEexUME,5568
32
32
  tensorneko/evaluation/iou.py,sha256=phEmOWQ3cnWW377WeSHCoB8mGkHLHMHCl8_LL0IX3JA,2914
33
33
  tensorneko/evaluation/psnr.py,sha256=DeKxvY_xxawWMXHY0z3Nvbsi4dR57OUV4hjtUoCINXc,3757
34
34
  tensorneko/evaluation/secs.py,sha256=D710GgcSxQgbGyPcWlC5ffF5n1GselLrUr5aA5Vq7oE,1622
35
35
  tensorneko/evaluation/ssim.py,sha256=6vPS4VQqoKxHOG49lChH51KxwNo07B4XHdhLub5DEPU,3758
36
36
  tensorneko/io/__init__.py,sha256=QEyA0mOC-BlKKskYYbDYttYWWRjCeh73lX-yKAUGNik,213
37
- tensorneko/io/reader.py,sha256=KB4xpdHKaqtEQXj2EOVB21Ev3ODPiQZFjNadZOipCMU,705
38
- tensorneko/io/writer.py,sha256=BHNtzROUY3AImx1QwVxbtZXuxMIfQq3WUI5PU1jeCpM,708
37
+ tensorneko/io/reader.py,sha256=MSEfmzbRTk69qawTPrYruaKvQ8TXSyu5ZmFY7xVn-aU,773
38
+ tensorneko/io/writer.py,sha256=MbO878ob24WSGKqjxyK-yRqp7xAiJBEtMxOGkOg_vOE,776
39
39
  tensorneko/io/mesh/__init__.py,sha256=cdR5QWNUgPaoU_fFcJO9sx7PeJy7pTlEvusjaivP1ok,72
40
40
  tensorneko/io/mesh/mesh_reader.py,sha256=ErUv9nBMARu-eR-uHlEhrb4bH0yl0cHiKoF9GSJ569A,184
41
41
  tensorneko/io/mesh/mesh_writer.py,sha256=d_lBhN2JEhaY79mwZALr-ylp0ZtyJmItCTUn_A3F4q0,184
42
+ tensorneko/io/weight/__init__.py,sha256=zlUTKTYL7uhOxgyR3VpoR408pYjG0XHOGihGDL6mHyc,79
43
+ tensorneko/io/weight/weight_reader.py,sha256=iDtphQsPQXzfKU_xaxvHJUEfealXHIcyoD2adXLpgg0,3079
44
+ tensorneko/io/weight/weight_writer.py,sha256=06SCNPw3vDQ-6MLSpb4oR1AuOQ1slSpOextLoAl4PTA,1609
42
45
  tensorneko/layer/__init__.py,sha256=HHCBzwR-8UEVsrUKz8j_dhQSVGRASg0flLc_pyG1JN0,1120
43
46
  tensorneko/layer/aggregation.py,sha256=ykH6u-NLJx4Yesu_BLWa6T-vWIYFzJXJV1Txrbz3mPE,1177
44
47
  tensorneko/layer/attention.py,sha256=sNC6gZgZaeHqUAfzrPh0Vefp7T6nc1gFvrreuOL7wUg,6360
@@ -47,7 +50,7 @@ tensorneko/layer/conv.py,sha256=c0XLcf2hRvyv6hMj4irIAYM0JyPzpv7tJBIW0Dl_sw8,5904
47
50
  tensorneko/layer/linear.py,sha256=NKIzSmSodk0wftu-f-6aSPazJ3X-eduRkZdLMMBjN_8,3283
48
51
  tensorneko/layer/log.py,sha256=UN7xgfzC7tF4P856taRtVgVHjc5CTTZqXYODvw6DhKQ,716
49
52
  tensorneko/layer/masked_conv2d.py,sha256=4DWjQ_Wc8lncYfk2Iur_M5V0mByvaQ2JZlOGRm29XLE,4796
50
- tensorneko/layer/noise.py,sha256=rc2Q1V1tT4KUszecOs82NeffrmDPOT9gELumo_tgdrU,1281
53
+ tensorneko/layer/noise.py,sha256=bbZyrIZUk41cKBuCj3BWiQ_XTIRY9YXyDrlBavQh2Lo,1269
51
54
  tensorneko/layer/patching.py,sha256=FYwqVXbOtiLgT5Cy63zISK9bTzMIKG4U8hO-n14liiM,6585
52
55
  tensorneko/layer/positional_embedding.py,sha256=2_r1sZQPYNEdBVLDByNV85gZDofuV8IZSIjSNoKJKhw,4283
53
56
  tensorneko/layer/reshape.py,sha256=7GmpjQkir2eW2CkWuiUK648C4kKHUV81EqNiCzk2QSY,813
@@ -75,14 +78,14 @@ tensorneko/util/dispatched_misc.py,sha256=_0Go7XezdYB7bpMnCs1MDD_6mPNoWP5qt8DoKu
75
78
  tensorneko/util/misc.py,sha256=LEvACtGDOX43iK86A8-Cek0S9rbXFR0AtTP1edE3XDI,4701
76
79
  tensorneko/util/reproducibility.py,sha256=sw1vVi7VOnmzQYUocI5x9yKeZoHHiA4A5ja136XolrI,2102
77
80
  tensorneko/util/string_getter.py,sha256=Cq2mDYr3q758xJ9OBTwLDf-b6EMSYwlnNB0-kfsElfs,2491
78
- tensorneko/util/type.py,sha256=IaLpRQ5l8Ci6FZaGRohIb1ygrnJ3NTalomxDbhz68VM,716
81
+ tensorneko/util/type.py,sha256=-dWknUu7RM4pFm7f3spgSlCIXpHA11RS36ol8uUFJgU,728
79
82
  tensorneko/visualization/__init__.py,sha256=PuNMhLz3oosY39AmKUr0biIgjfc_G_rQzp960me08Fg,626
80
83
  tensorneko/visualization/log_graph.py,sha256=NvOwWVc_petXWYdgaHosPFLa43sHBeacbYcfNtdRQg4,1511
81
84
  tensorneko/visualization/matplotlib.py,sha256=xs9Ssc44ojZX65QU8-fftA7Ug_pBuZ3TBtM8vETNq9w,1568
82
85
  tensorneko/visualization/image_browser/__init__.py,sha256=AtykhAE3bXQS6SOWbeYFeeUE9ts9XOFMvrL31z0LoMg,63
83
86
  tensorneko/visualization/watcher/__init__.py,sha256=Nq752qIYvfRUZ8VctKQRSqhxh5KmFbWcqPfZlijVx6s,379
84
- tensorneko-0.3.7.dist-info/LICENSE,sha256=Vd75kwgJpVuMnCRBWasQzceMlXt4YQL13ikBLy8G5h0,1067
85
- tensorneko-0.3.7.dist-info/METADATA,sha256=877cbLhSBb-olKLWCN3lFnjfZ3i3_aHEVCRv-oTnWRI,18892
86
- tensorneko-0.3.7.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
87
- tensorneko-0.3.7.dist-info/top_level.txt,sha256=sZHwlP0iyk7_zHuhRHzSBkdY9yEgyC48f6UVuZ6CvqE,11
88
- tensorneko-0.3.7.dist-info/RECORD,,
87
+ tensorneko-0.3.9.dist-info/LICENSE,sha256=Vd75kwgJpVuMnCRBWasQzceMlXt4YQL13ikBLy8G5h0,1067
88
+ tensorneko-0.3.9.dist-info/METADATA,sha256=QIW5lNPHf27Np7FwVhn3HoWeD-AQCe_uid7RApT9R-k,18892
89
+ tensorneko-0.3.9.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
90
+ tensorneko-0.3.9.dist-info/top_level.txt,sha256=sZHwlP0iyk7_zHuhRHzSBkdY9yEgyC48f6UVuZ6CvqE,11
91
+ tensorneko-0.3.9.dist-info/RECORD,,