tensorneko 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensorneko/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
1
  import os.path
2
2
 
3
- from tensorneko_util import io
4
3
  from . import backend
5
4
  from . import callback
6
5
  from . import dataset
7
6
  from . import debug
8
7
  from . import evaluation
8
+ from . import io
9
9
  from . import layer
10
10
  from . import module
11
11
  from . import notebook
@@ -13,10 +13,13 @@ from . import optim
13
13
  from . import preprocess
14
14
  from . import util
15
15
  from . import visualization
16
+ from .io import read, write
16
17
  from .neko_model import NekoModel
17
18
  from .neko_module import NekoModule
18
19
  from .neko_trainer import NekoTrainer
19
20
 
21
+ __version__ = io.read.text(os.path.join(util.get_tensorneko_path(), "version.txt"))
22
+
20
23
  __all__ = [
21
24
  "callback",
22
25
  "dataset",
@@ -33,7 +36,9 @@ __all__ = [
33
36
  "debug",
34
37
  "NekoModel",
35
38
  "NekoTrainer",
36
- "NekoModule"
39
+ "NekoModule",
40
+ "read",
41
+ "write",
37
42
  ]
38
43
 
39
- __version__ = io.read.text(os.path.join(util.get_tensorneko_path(), "version.txt"))
44
+
@@ -1,8 +1,10 @@
1
- from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib
1
+ from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib, import_tqdm_auto, import_tqdm
2
2
 
3
3
  __all__ = [
4
4
  "parallel",
5
5
  "run_blocking",
6
6
  "VisualLib",
7
7
  "AudioLib",
8
+ "import_tqdm_auto",
9
+ "import_tqdm",
8
10
  ]
@@ -12,12 +12,13 @@ from torchmetrics.image.fid import FrechetInceptionDistance
12
12
  from torchvision.transforms.functional import resize
13
13
 
14
14
  from tensorneko_util.backend import VisualLib
15
- from tensorneko_util.backend._tqdm import import_tqdm_auto
15
+ from tensorneko_util.backend.tqdm import import_tqdm_auto
16
16
 
17
17
  try:
18
18
  from typing import Literal
19
19
  TypeOption = Literal["video", "image"]
20
20
  except ImportError:
21
+ Literal = None
21
22
  TypeOption = str
22
23
 
23
24
 
tensorneko/io/reader.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
  from typing import Type
3
3
 
4
4
  from tensorneko_util.io.reader import Reader as BaseReader
5
+ from .weight import WeightReader
5
6
 
6
7
  try:
7
8
  from .mesh import MeshReader
@@ -15,6 +16,7 @@ class Reader(BaseReader):
15
16
 
16
17
  def __init__(self):
17
18
  super().__init__()
19
+ self.weight = WeightReader
18
20
  self._mesh = None
19
21
 
20
22
  @property
@@ -0,0 +1,2 @@
1
+ from .weight_reader import WeightReader
2
+ from .weight_writer import WeightWriter
@@ -0,0 +1,81 @@
1
+ from typing import OrderedDict
2
+
3
+ import torch
4
+
5
+ from ...util import Device
6
+
7
+
8
+ class WeightReader:
9
+ """WeightReader for read model weights (checkpoints, state_dict, etc)."""
10
+
11
+ @classmethod
12
+ def of_pt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
13
+ """
14
+ Reads PyTorch model weights from a `.pt` or `.pth` file.
15
+
16
+ Args:
17
+ path (``str``): The path of the `.pt` or `.pth` file.
18
+ map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
19
+
20
+ Returns:
21
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
22
+ """
23
+ return torch.load(path, map_location=map_location)
24
+
25
+ @classmethod
26
+ def of_ckpt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
27
+ """
28
+ Reads PyTorch model weights from a `.ckpt` file.
29
+
30
+ Args:
31
+ path (``str``): The path of the `.ckpt` file.
32
+ map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
33
+
34
+ Returns:
35
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
36
+ """
37
+ return torch.load(path, map_location=map_location)["state_dict"]
38
+
39
+ @classmethod
40
+ def of_safetensors(cls, path: str, map_location: str = "cpu") -> OrderedDict[str, torch.Tensor]:
41
+ """
42
+ Reads model weights from a `.safetensors` file.
43
+
44
+ Args:
45
+ path (``str``): The path of the `.safetensors` file.
46
+ map_location (``str``): The location to load the model weights. Default: "cpu"
47
+
48
+ Returns:
49
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
50
+ """
51
+ import safetensors
52
+ from collections import OrderedDict
53
+ tensors = OrderedDict()
54
+ with safetensors.safe_open(path, framework="pt", device=map_location) as f:
55
+ for key in f.keys():
56
+ tensors[key] = f.get_tensor(key)
57
+ return tensors
58
+
59
+ @classmethod
60
+ def of(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
61
+ """
62
+ Reads model weights from a file.
63
+
64
+ Args:
65
+ path (``str``): The path of the file.
66
+ map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
67
+
68
+ Returns:
69
+ :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
70
+ """
71
+
72
+ if path.endswith(".pt") or path.endswith(".pth"):
73
+ return cls.of_pt(path, map_location)
74
+ elif path.endswith(".ckpt"):
75
+ return cls.of_ckpt(path, map_location)
76
+ elif path.endswith(".safetensors"):
77
+ if isinstance(map_location, torch.device):
78
+ map_location = str(map_location)
79
+ return cls.of_safetensors(path, map_location)
80
+ else:
81
+ raise ValueError("Unknown file type. Supported types: .pt, .pth, .ckpt, .safetensors")
@@ -0,0 +1,48 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+
5
+
6
+ class WeightWriter:
7
+ """WeightWriter for write model weights (checkpoints, state_dict, etc)."""
8
+
9
+ @classmethod
10
+ def to_pt(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
11
+ """
12
+ Writes PyTorch model weights to a `.pt` or `.pth` file.
13
+
14
+ Args:
15
+ path (``str``): The path of the `.pt` or `.pth` file.
16
+ weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
17
+ """
18
+ torch.save(weights, path)
19
+
20
+ @classmethod
21
+ def to_safetensors(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
22
+ """
23
+ Writes model weights to a `.safetensors` file.
24
+
25
+ Args:
26
+ path (``str``): The path of the `.safetensors` file.
27
+ weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
28
+ """
29
+ import safetensors.torch
30
+ safetensors.torch.save_file(weights, path)
31
+
32
+ @classmethod
33
+ def to(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
34
+ """
35
+ Writes model weights to a file.
36
+
37
+ Args:
38
+ path (``str``): The path of the file.
39
+ weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
40
+ """
41
+ file_type = path.split(".")[-1]
42
+
43
+ if file_type == "pt":
44
+ cls.to_pt(path, weights)
45
+ elif file_type == "safetensors":
46
+ cls.to_safetensors(path, weights)
47
+ else:
48
+ raise ValueError("Unknown file type. Supported types: .pt, .safetensors")
tensorneko/io/writer.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from typing import Type
4
4
 
5
5
  from tensorneko_util.io.writer import Writer as BaseWriter
6
+ from .weight import WeightWriter
6
7
 
7
8
  try:
8
9
  from .mesh import MeshWriter
@@ -16,6 +17,7 @@ class Writer(BaseWriter):
16
17
 
17
18
  def __init__(self):
18
19
  super().__init__()
20
+ self.weight = WeightWriter
19
21
  self._mesh = None
20
22
 
21
23
  @property
tensorneko/layer/noise.py CHANGED
@@ -22,7 +22,7 @@ class GaussianNoise(NekoModule):
22
22
  from https://discuss.pytorch.org/t/writing-a-simple-gaussian-noise-layer-in-pytorch/4694/3
23
23
  """
24
24
 
25
- def __init__(self, sigma=0.1, device: Union[Device, str] = "cuda"):
25
+ def __init__(self, sigma=0.1, device: Device = "cuda"):
26
26
  super().__init__()
27
27
  self.sigma = sigma
28
28
  self.noise = torch.tensor(0.).to(device)
tensorneko/util/type.py CHANGED
@@ -11,7 +11,7 @@ ModuleFactory = Union[Callable[[], Module], Callable[[int], Module]]
11
11
  """The module builder type of ``() -> torch.nn.Module | (int) -> torch.nn.Module``"""
12
12
 
13
13
 
14
- Device = device
14
+ Device = Union[str, device]
15
15
  """Device type of :class:`torch.device`"""
16
16
 
17
17
 
tensorneko/version.txt CHANGED
@@ -1 +1 @@
1
- 0.3.8
1
+ 0.3.9
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tensorneko
3
- Version: 0.3.8
3
+ Version: 0.3.9
4
4
  Summary: Tensor Neural Engine Kompanion. An util library based on PyTorch and PyTorch Lightning.
5
5
  Home-page: https://github.com/ControlNet/tensorneko
6
6
  Author: ControlNet
@@ -33,7 +33,7 @@ Requires-Dist: pillow >=8.1
33
33
  Requires-Dist: av >=8.0.3
34
34
  Requires-Dist: numpy >=1.20.1
35
35
  Requires-Dist: einops >=0.3.0
36
- Requires-Dist: tensorneko-util ==0.3.8
36
+ Requires-Dist: tensorneko-util ==0.3.9
37
37
  Requires-Dist: pysoundfile >=0.9.0 ; platform_system == "Windows"
38
38
 
39
39
  <h1 style="text-align: center">TensorNeko</h1>
@@ -1,15 +1,15 @@
1
- tensorneko/__init__.py,sha256=VPPK00Kduwi84QHnZKBZm8kBRdnPAji6f7J-adYAp_Y,770
1
+ tensorneko/__init__.py,sha256=lLVC4StQ5q9OFJTceNrJbj-CIFHFMRvFQyGYJTjuRM4,812
2
2
  tensorneko/neko_model.py,sha256=QTbdOAg9ki0ix6mDL_Qu8Wmd5WJOoUFF3M1SXEp3KGc,10551
3
3
  tensorneko/neko_module.py,sha256=qELXvguSjWo_NvcRQibiFl0Qauzd9JWLSnT4dbGNS3Y,1473
4
4
  tensorneko/neko_trainer.py,sha256=JC8qoKSZ5ngz3grf3S0SjvIFVktDIP_GExth5aFfbGA,10074
5
- tensorneko/version.txt,sha256=L_j9rGHiSDgBzj7QgRwf9tGYLHGVnIvOsy9NcC-vNsU,5
5
+ tensorneko/version.txt,sha256=JCuzYNF_l66JsKuMd57Z4WJC2x1-36y0IfHrRwznDqM,5
6
6
  tensorneko/arch/__init__.py,sha256=w4lTUeyBIZelrnSjlBFWUF0erzOmBFl9FqeWQuSOyKs,248
7
7
  tensorneko/arch/auto_encoder.py,sha256=j6PWWyaNYaYNtw_zZ9ikzhCASqe9viXR3JGBIXSK92Y,2137
8
8
  tensorneko/arch/binary_classifier.py,sha256=1MkEbReXKLdDksRG5Rsife40grJk08EVDcNKp54Xvb4,2316
9
9
  tensorneko/arch/gan.py,sha256=bHh8s9UxHNSj9N6C6SllAaVJ61JhbY4ltwmQ-a1543w,7215
10
10
  tensorneko/arch/vqvae.py,sha256=_HJoDVBonB4PHymaV9JaAecS-NJZ1jfrKAhoqcLzf0g,3750
11
11
  tensorneko/arch/wgan.py,sha256=Oj1vZPXzhgTPHGGXGmbStgylskjjzulr6tYxj7E0zpQ,4722
12
- tensorneko/backend/__init__.py,sha256=hWEoQo3Bh7Fn7cfaJznA3k1sHWAUIJ9KAbDxOVjySxU,164
12
+ tensorneko/backend/__init__.py,sha256=ppJhb1MC_WK6XLK2fX8x0z7Cn-8CkF97x67XBLoRCys,238
13
13
  tensorneko/callback/__init__.py,sha256=H1jOTsSYm9c4sxtcV9_uzumXZ95b4gbiftKqimNER7s,556
14
14
  tensorneko/callback/display_metrics_callback.py,sha256=qzhHcb68B7o9byfD1ZqEitSVkrwsSGOF-u59_Ip9dEg,318
15
15
  tensorneko/callback/earlystop_lr.py,sha256=d0G3NHYi-tNNCdaB7Rt7vVceI8CXOGT4zBTJYNciyug,1195
@@ -28,17 +28,20 @@ tensorneko/dataset/sampler/sequential_iter_sampler.py,sha256=DxBwSoWjYlq6kA6g-54
28
28
  tensorneko/debug/__init__.py,sha256=ZMfU3qquhMhl6EgPzM7Yuvvv0PWy3cR39UjPrrSmQcs,163
29
29
  tensorneko/evaluation/__init__.py,sha256=jW8dh1JRMpx3npjTp7wJLzz-IxFZTBh7F-Ztfoep9xs,296
30
30
  tensorneko/evaluation/enum.py,sha256=s3P8XAobku-as4in5vh6BanvVW5Ccwnff0t124lVFFg,137
31
- tensorneko/evaluation/fid.py,sha256=mDsgh7Ge7K8KrOLeWnSEVzzKfdCK0cI9TAWJJd5eqcQ,5550
31
+ tensorneko/evaluation/fid.py,sha256=fNuE1CEp2rPXbaZfI0E1CspluInzFlUdKc8XZEexUME,5568
32
32
  tensorneko/evaluation/iou.py,sha256=phEmOWQ3cnWW377WeSHCoB8mGkHLHMHCl8_LL0IX3JA,2914
33
33
  tensorneko/evaluation/psnr.py,sha256=DeKxvY_xxawWMXHY0z3Nvbsi4dR57OUV4hjtUoCINXc,3757
34
34
  tensorneko/evaluation/secs.py,sha256=D710GgcSxQgbGyPcWlC5ffF5n1GselLrUr5aA5Vq7oE,1622
35
35
  tensorneko/evaluation/ssim.py,sha256=6vPS4VQqoKxHOG49lChH51KxwNo07B4XHdhLub5DEPU,3758
36
36
  tensorneko/io/__init__.py,sha256=QEyA0mOC-BlKKskYYbDYttYWWRjCeh73lX-yKAUGNik,213
37
- tensorneko/io/reader.py,sha256=KB4xpdHKaqtEQXj2EOVB21Ev3ODPiQZFjNadZOipCMU,705
38
- tensorneko/io/writer.py,sha256=BHNtzROUY3AImx1QwVxbtZXuxMIfQq3WUI5PU1jeCpM,708
37
+ tensorneko/io/reader.py,sha256=MSEfmzbRTk69qawTPrYruaKvQ8TXSyu5ZmFY7xVn-aU,773
38
+ tensorneko/io/writer.py,sha256=MbO878ob24WSGKqjxyK-yRqp7xAiJBEtMxOGkOg_vOE,776
39
39
  tensorneko/io/mesh/__init__.py,sha256=cdR5QWNUgPaoU_fFcJO9sx7PeJy7pTlEvusjaivP1ok,72
40
40
  tensorneko/io/mesh/mesh_reader.py,sha256=ErUv9nBMARu-eR-uHlEhrb4bH0yl0cHiKoF9GSJ569A,184
41
41
  tensorneko/io/mesh/mesh_writer.py,sha256=d_lBhN2JEhaY79mwZALr-ylp0ZtyJmItCTUn_A3F4q0,184
42
+ tensorneko/io/weight/__init__.py,sha256=zlUTKTYL7uhOxgyR3VpoR408pYjG0XHOGihGDL6mHyc,79
43
+ tensorneko/io/weight/weight_reader.py,sha256=iDtphQsPQXzfKU_xaxvHJUEfealXHIcyoD2adXLpgg0,3079
44
+ tensorneko/io/weight/weight_writer.py,sha256=06SCNPw3vDQ-6MLSpb4oR1AuOQ1slSpOextLoAl4PTA,1609
42
45
  tensorneko/layer/__init__.py,sha256=HHCBzwR-8UEVsrUKz8j_dhQSVGRASg0flLc_pyG1JN0,1120
43
46
  tensorneko/layer/aggregation.py,sha256=ykH6u-NLJx4Yesu_BLWa6T-vWIYFzJXJV1Txrbz3mPE,1177
44
47
  tensorneko/layer/attention.py,sha256=sNC6gZgZaeHqUAfzrPh0Vefp7T6nc1gFvrreuOL7wUg,6360
@@ -47,7 +50,7 @@ tensorneko/layer/conv.py,sha256=c0XLcf2hRvyv6hMj4irIAYM0JyPzpv7tJBIW0Dl_sw8,5904
47
50
  tensorneko/layer/linear.py,sha256=NKIzSmSodk0wftu-f-6aSPazJ3X-eduRkZdLMMBjN_8,3283
48
51
  tensorneko/layer/log.py,sha256=UN7xgfzC7tF4P856taRtVgVHjc5CTTZqXYODvw6DhKQ,716
49
52
  tensorneko/layer/masked_conv2d.py,sha256=4DWjQ_Wc8lncYfk2Iur_M5V0mByvaQ2JZlOGRm29XLE,4796
50
- tensorneko/layer/noise.py,sha256=rc2Q1V1tT4KUszecOs82NeffrmDPOT9gELumo_tgdrU,1281
53
+ tensorneko/layer/noise.py,sha256=bbZyrIZUk41cKBuCj3BWiQ_XTIRY9YXyDrlBavQh2Lo,1269
51
54
  tensorneko/layer/patching.py,sha256=FYwqVXbOtiLgT5Cy63zISK9bTzMIKG4U8hO-n14liiM,6585
52
55
  tensorneko/layer/positional_embedding.py,sha256=2_r1sZQPYNEdBVLDByNV85gZDofuV8IZSIjSNoKJKhw,4283
53
56
  tensorneko/layer/reshape.py,sha256=7GmpjQkir2eW2CkWuiUK648C4kKHUV81EqNiCzk2QSY,813
@@ -75,14 +78,14 @@ tensorneko/util/dispatched_misc.py,sha256=_0Go7XezdYB7bpMnCs1MDD_6mPNoWP5qt8DoKu
75
78
  tensorneko/util/misc.py,sha256=LEvACtGDOX43iK86A8-Cek0S9rbXFR0AtTP1edE3XDI,4701
76
79
  tensorneko/util/reproducibility.py,sha256=sw1vVi7VOnmzQYUocI5x9yKeZoHHiA4A5ja136XolrI,2102
77
80
  tensorneko/util/string_getter.py,sha256=Cq2mDYr3q758xJ9OBTwLDf-b6EMSYwlnNB0-kfsElfs,2491
78
- tensorneko/util/type.py,sha256=IaLpRQ5l8Ci6FZaGRohIb1ygrnJ3NTalomxDbhz68VM,716
81
+ tensorneko/util/type.py,sha256=-dWknUu7RM4pFm7f3spgSlCIXpHA11RS36ol8uUFJgU,728
79
82
  tensorneko/visualization/__init__.py,sha256=PuNMhLz3oosY39AmKUr0biIgjfc_G_rQzp960me08Fg,626
80
83
  tensorneko/visualization/log_graph.py,sha256=NvOwWVc_petXWYdgaHosPFLa43sHBeacbYcfNtdRQg4,1511
81
84
  tensorneko/visualization/matplotlib.py,sha256=xs9Ssc44ojZX65QU8-fftA7Ug_pBuZ3TBtM8vETNq9w,1568
82
85
  tensorneko/visualization/image_browser/__init__.py,sha256=AtykhAE3bXQS6SOWbeYFeeUE9ts9XOFMvrL31z0LoMg,63
83
86
  tensorneko/visualization/watcher/__init__.py,sha256=Nq752qIYvfRUZ8VctKQRSqhxh5KmFbWcqPfZlijVx6s,379
84
- tensorneko-0.3.8.dist-info/LICENSE,sha256=Vd75kwgJpVuMnCRBWasQzceMlXt4YQL13ikBLy8G5h0,1067
85
- tensorneko-0.3.8.dist-info/METADATA,sha256=1i0yJukLciQ3ASW4rpphTBK_dEcTtnWpEn64DTCpk1s,18892
86
- tensorneko-0.3.8.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
87
- tensorneko-0.3.8.dist-info/top_level.txt,sha256=sZHwlP0iyk7_zHuhRHzSBkdY9yEgyC48f6UVuZ6CvqE,11
88
- tensorneko-0.3.8.dist-info/RECORD,,
87
+ tensorneko-0.3.9.dist-info/LICENSE,sha256=Vd75kwgJpVuMnCRBWasQzceMlXt4YQL13ikBLy8G5h0,1067
88
+ tensorneko-0.3.9.dist-info/METADATA,sha256=QIW5lNPHf27Np7FwVhn3HoWeD-AQCe_uid7RApT9R-k,18892
89
+ tensorneko-0.3.9.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
90
+ tensorneko-0.3.9.dist-info/top_level.txt,sha256=sZHwlP0iyk7_zHuhRHzSBkdY9yEgyC48f6UVuZ6CvqE,11
91
+ tensorneko-0.3.9.dist-info/RECORD,,