tensorneko 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensorneko/io/reader.py CHANGED
@@ -1,6 +1,9 @@
1
1
  from __future__ import annotations
2
- from typing import Type
3
2
 
3
+ from pathlib import Path
4
+ from typing import Type, Union, Any
5
+
6
+ from tensorneko_util.io._path_conversion import _path2str
4
7
  from tensorneko_util.io.reader import Reader as BaseReader
5
8
  from .weight import WeightReader
6
9
 
@@ -28,3 +31,11 @@ class Reader(BaseReader):
28
31
  return self._mesh
29
32
  else:
30
33
  raise ImportError("To use the mesh reader, please install pytorch3d.")
34
+
35
+ def __call__(self, path: Union[str, Path], *args, **kwargs) -> Any:
36
+ """Automatically infer the file type and return the corresponding result. """
37
+ path = _path2str(path)
38
+ if path.endswith(".pt") or path.endswith(".pth") or path.endswith(".ckpt") or path.endswith(".safetensors"):
39
+ return self.weight(path, *args, **kwargs)
40
+ else:
41
+ return super().__call__(path, *args, **kwargs)
@@ -1,7 +1,9 @@
1
- from typing import OrderedDict
1
+ from typing import OrderedDict, Union
2
+ from pathlib import Path
2
3
 
3
4
  import torch
4
5
 
6
+ from tensorneko_util.io._path_conversion import _path2str
5
7
  from ...util import Device
6
8
 
7
9
 
@@ -9,45 +11,48 @@ class WeightReader:
9
11
  """WeightReader for read model weights (checkpoints, state_dict, etc)."""
10
12
 
11
13
  @classmethod
12
- def of_pt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
14
+ def of_pt(cls, path: Union[str, Path], map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
13
15
  """
14
16
  Reads PyTorch model weights from a `.pt` or `.pth` file.
15
17
 
16
18
  Args:
17
- path (``str``): The path of the `.pt` or `.pth` file.
19
+ path (``str`` | ``pathlib.Path``): The path of the `.pt` or `.pth` file.
18
20
  map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
19
21
 
20
22
  Returns:
21
23
  :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
22
24
  """
25
+ path = _path2str(path)
23
26
  return torch.load(path, map_location=map_location)
24
27
 
25
28
  @classmethod
26
- def of_ckpt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
29
+ def of_ckpt(cls, path: Union[str, Path], map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
27
30
  """
28
31
  Reads PyTorch model weights from a `.ckpt` file.
29
32
 
30
33
  Args:
31
- path (``str``): The path of the `.ckpt` file.
34
+ path (``str`` | ``pathlib.Path``): The path of the `.ckpt` file.
32
35
  map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
33
36
 
34
37
  Returns:
35
38
  :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
36
39
  """
40
+ path = _path2str(path)
37
41
  return torch.load(path, map_location=map_location)["state_dict"]
38
42
 
39
43
  @classmethod
40
- def of_safetensors(cls, path: str, map_location: str = "cpu") -> OrderedDict[str, torch.Tensor]:
44
+ def of_safetensors(cls, path: Union[str, Path], map_location: str = "cpu") -> OrderedDict[str, torch.Tensor]:
41
45
  """
42
46
  Reads model weights from a `.safetensors` file.
43
47
 
44
48
  Args:
45
- path (``str``): The path of the `.safetensors` file.
49
+ path (``str`` | ``pathlib.Path``): The path of the `.safetensors` file.
46
50
  map_location (``str``): The location to load the model weights. Default: "cpu"
47
51
 
48
52
  Returns:
49
53
  :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
50
54
  """
55
+ path = _path2str(path)
51
56
  import safetensors
52
57
  from collections import OrderedDict
53
58
  tensors = OrderedDict()
@@ -57,18 +62,18 @@ class WeightReader:
57
62
  return tensors
58
63
 
59
64
  @classmethod
60
- def of(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
65
+ def of(cls, path: Union[str, Path], map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
61
66
  """
62
67
  Reads model weights from a file.
63
68
 
64
69
  Args:
65
- path (``str``): The path of the file.
70
+ path (``str`` | ``pathlib.Path``): The path of the file.
66
71
  map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
67
72
 
68
73
  Returns:
69
74
  :class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
70
75
  """
71
-
76
+ path = _path2str(path)
72
77
  if path.endswith(".pt") or path.endswith(".pth"):
73
78
  return cls.of_pt(path, map_location)
74
79
  elif path.endswith(".ckpt"):
@@ -79,3 +84,8 @@ class WeightReader:
79
84
  return cls.of_safetensors(path, map_location)
80
85
  else:
81
86
  raise ValueError("Unknown file type. Supported types: .pt, .pth, .ckpt, .safetensors")
87
+
88
+ def __new__(cls, path: Union[str, Path], map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
89
+ """Alias of :meth:`~tensorneko.io.weight.weight_reader.WeightReader.of`."""
90
+ path = _path2str(path)
91
+ return cls.of(path, map_location)
@@ -1,48 +1,59 @@
1
- from typing import Dict
1
+ from typing import Dict, Union
2
+ from pathlib import Path
2
3
 
3
4
  import torch
4
5
 
6
+ from tensorneko_util.io._path_conversion import _path2str
7
+ from ...util import Device
5
8
 
6
9
  class WeightWriter:
7
10
  """WeightWriter for write model weights (checkpoints, state_dict, etc)."""
8
11
 
9
12
  @classmethod
10
- def to_pt(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
13
+ def to_pt(cls, path: Union[str, Path], weights: Dict[str, torch.Tensor]) -> None:
11
14
  """
12
15
  Writes PyTorch model weights to a `.pt` or `.pth` file.
13
16
 
14
17
  Args:
15
- path (``str``): The path of the `.pt` or `.pth` file.
18
+ path (``str`` | ``pathlib.Path``): The path of the `.pt` or `.pth` file.
16
19
  weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
17
20
  """
21
+ path = _path2str(path)
18
22
  torch.save(weights, path)
19
23
 
20
24
  @classmethod
21
- def to_safetensors(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
25
+ def to_safetensors(cls, path: Union[str, Path], weights: Dict[str, torch.Tensor]) -> None:
22
26
  """
23
27
  Writes model weights to a `.safetensors` file.
24
28
 
25
29
  Args:
26
- path (``str``): The path of the `.safetensors` file.
30
+ path (``str`` | ``pathlib.Path``): The path of the `.safetensors` file.
27
31
  weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
28
32
  """
33
+ path = _path2str(path)
29
34
  import safetensors.torch
30
35
  safetensors.torch.save_file(weights, path)
31
36
 
32
37
  @classmethod
33
- def to(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
38
+ def to(cls, path: Union[str, Path], weights: Dict[str, torch.Tensor]) -> None:
34
39
  """
35
40
  Writes model weights to a file.
36
41
 
37
42
  Args:
38
- path (``str``): The path of the file.
43
+ path (``str`` | ``pathlib.Path``): The path of the file.
39
44
  weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
40
45
  """
46
+ path = _path2str(path)
41
47
  file_type = path.split(".")[-1]
42
48
 
43
- if file_type == "pt":
49
+ if file_type in ("pt", "pth"):
44
50
  cls.to_pt(path, weights)
45
51
  elif file_type == "safetensors":
46
52
  cls.to_safetensors(path, weights)
47
53
  else:
48
54
  raise ValueError("Unknown file type. Supported types: .pt, .safetensors")
55
+
56
+ def __new__(cls, path: Union[str, Path], weights: Dict[str, torch.Tensor]) -> None:
57
+ """Alias of :meth:`~tensorneko.io.weight.weight_writer.WeightWriter.to`."""
58
+ path = _path2str(path)
59
+ return cls.to(path, weights)
tensorneko/io/writer.py CHANGED
@@ -1,7 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Type
3
+ from pathlib import Path
4
+ from typing import Type, Union
4
5
 
6
+ from tensorneko_util.io._path_conversion import _path2str
5
7
  from tensorneko_util.io.writer import Writer as BaseWriter
6
8
  from .weight import WeightWriter
7
9
 
@@ -29,3 +31,12 @@ class Writer(BaseWriter):
29
31
  return self._mesh
30
32
  else:
31
33
  raise ImportError("To use the mesh writer, please install pytorch3d.")
34
+
35
+ def __call__(self, path: Union[str, Path], obj, *args, **kwargs):
36
+ """Automatically infer the file type and return the corresponding result. """
37
+ path = _path2str(path)
38
+
39
+ if path.endswith(".pt") or path.endswith(".pth") or path.endswith(".ckpt") or path.endswith(".safetensors"):
40
+ return self.weight(path, obj)
41
+ else:
42
+ return super().__call__(path, obj, *args, **kwargs)
@@ -3,5 +3,5 @@ from tensorneko_util.notebook import ipython_available
3
3
  __all__ = []
4
4
 
5
5
  if ipython_available:
6
- from tensorneko_util.notebook import display
7
- __all__.append("display")
6
+ from tensorneko_util.notebook import display, animation
7
+ __all__.extend(["display", "animation"])
@@ -4,7 +4,8 @@ from tensorneko_util.util import dispatch, AverageMeter, tensorneko_util_path
4
4
  from tensorneko_util.util.fp import Seq, AbstractSeq, curry, F, Stream, return_option, Option, Monad, Eval, _, __
5
5
  from tensorneko_util.util import ref, Timer, Singleton
6
6
  from tensorneko_util.util.eventbus import Event, EventBus, EventHandler, subscribe
7
- from tensorneko_util.util import download_file, download_file_thread, download_files_thread, WindowMerger, Registry
7
+ from tensorneko_util.util import download_file, download_file_thread, download_files_thread, WindowMerger, Registry, \
8
+ MultiLayerIndexer
8
9
  from . import type
9
10
  from .configuration import Configuration
10
11
  from .misc import reduce_dict_by, summarize_dict_by, with_printed_shape, is_bad_num, count_parameters, compose, \
@@ -78,5 +79,6 @@ __all__ = [
78
79
  "download_files_thread",
79
80
  "WindowMerger",
80
81
  "Registry",
81
- "run_gc"
82
+ "run_gc",
83
+ "MultiLayerIndexer"
82
84
  ]
tensorneko/version.txt CHANGED
@@ -1 +1 @@
1
- 0.3.17
1
+ 0.3.19
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tensorneko
3
- Version: 0.3.17
3
+ Version: 0.3.19
4
4
  Summary: Tensor Neural Engine Kompanion. An util library based on PyTorch and PyTorch Lightning.
5
5
  Home-page: https://github.com/ControlNet/tensorneko
6
6
  Author: ControlNet
@@ -27,7 +27,7 @@ Requires-Dist: av (>=8.0.3)
27
27
  Requires-Dist: einops (>=0.3.0)
28
28
  Requires-Dist: numpy (>=1.20.1)
29
29
  Requires-Dist: pillow (>=8.1)
30
- Requires-Dist: tensorneko-util (==0.3.17)
30
+ Requires-Dist: tensorneko-util (==0.3.19)
31
31
  Requires-Dist: torch (>=1.9.0)
32
32
  Requires-Dist: torchaudio (>=0.9.0)
33
33
  Requires-Dist: torchmetrics (>=0.7.3)
@@ -2,7 +2,7 @@ tensorneko/__init__.py,sha256=uh1HNn1sNpX1bbOqAE_kNJfrH4eMtEzus0hO-Fh9tEw,990
2
2
  tensorneko/neko_model.py,sha256=QTbdOAg9ki0ix6mDL_Qu8Wmd5WJOoUFF3M1SXEp3KGc,10551
3
3
  tensorneko/neko_module.py,sha256=qELXvguSjWo_NvcRQibiFl0Qauzd9JWLSnT4dbGNS3Y,1473
4
4
  tensorneko/neko_trainer.py,sha256=JC8qoKSZ5ngz3grf3S0SjvIFVktDIP_GExth5aFfbGA,10074
5
- tensorneko/version.txt,sha256=sOiNkmvwnZw9dn4Od6oRM8jDNFGeaDoTmEVv7sd2w2k,6
5
+ tensorneko/version.txt,sha256=xV2o_U6zdvujab1PTmKLrZ6dQjt1fW9CyXhxWiQJBpo,6
6
6
  tensorneko/arch/__init__.py,sha256=w4lTUeyBIZelrnSjlBFWUF0erzOmBFl9FqeWQuSOyKs,248
7
7
  tensorneko/arch/auto_encoder.py,sha256=j6PWWyaNYaYNtw_zZ9ikzhCASqe9viXR3JGBIXSK92Y,2137
8
8
  tensorneko/arch/binary_classifier.py,sha256=1MkEbReXKLdDksRG5Rsife40grJk08EVDcNKp54Xvb4,2316
@@ -34,14 +34,14 @@ tensorneko/evaluation/psnr.py,sha256=DeKxvY_xxawWMXHY0z3Nvbsi4dR57OUV4hjtUoCINXc
34
34
  tensorneko/evaluation/secs.py,sha256=D710GgcSxQgbGyPcWlC5ffF5n1GselLrUr5aA5Vq7oE,1622
35
35
  tensorneko/evaluation/ssim.py,sha256=6vPS4VQqoKxHOG49lChH51KxwNo07B4XHdhLub5DEPU,3758
36
36
  tensorneko/io/__init__.py,sha256=QEyA0mOC-BlKKskYYbDYttYWWRjCeh73lX-yKAUGNik,213
37
- tensorneko/io/reader.py,sha256=MSEfmzbRTk69qawTPrYruaKvQ8TXSyu5ZmFY7xVn-aU,773
38
- tensorneko/io/writer.py,sha256=MbO878ob24WSGKqjxyK-yRqp7xAiJBEtMxOGkOg_vOE,776
37
+ tensorneko/io/reader.py,sha256=DSeTGLh84sFYwCwJmNTr-fGWkluudCbf7je29t0Z2U8,1303
38
+ tensorneko/io/writer.py,sha256=BR_1h-wXekBdctXymJBU44HoWsKxPhbbh6N3AKYNkjE,1292
39
39
  tensorneko/io/mesh/__init__.py,sha256=cdR5QWNUgPaoU_fFcJO9sx7PeJy7pTlEvusjaivP1ok,72
40
40
  tensorneko/io/mesh/mesh_reader.py,sha256=ErUv9nBMARu-eR-uHlEhrb4bH0yl0cHiKoF9GSJ569A,184
41
41
  tensorneko/io/mesh/mesh_writer.py,sha256=d_lBhN2JEhaY79mwZALr-ylp0ZtyJmItCTUn_A3F4q0,184
42
42
  tensorneko/io/weight/__init__.py,sha256=zlUTKTYL7uhOxgyR3VpoR408pYjG0XHOGihGDL6mHyc,79
43
- tensorneko/io/weight/weight_reader.py,sha256=iDtphQsPQXzfKU_xaxvHJUEfealXHIcyoD2adXLpgg0,3079
44
- tensorneko/io/weight/weight_writer.py,sha256=06SCNPw3vDQ-6MLSpb4oR1AuOQ1slSpOextLoAl4PTA,1609
43
+ tensorneko/io/weight/weight_reader.py,sha256=3W9MyULuYtY6rjBrQvJ0F5Lu0Yb0Wvha0aSebZis6eU,3688
44
+ tensorneko/io/weight/weight_writer.py,sha256=yjtLXi5negrodfirEHzbjHrnZ2ALa6GZETBc1UxA5t4,2165
45
45
  tensorneko/layer/__init__.py,sha256=HHCBzwR-8UEVsrUKz8j_dhQSVGRASg0flLc_pyG1JN0,1120
46
46
  tensorneko/layer/aggregation.py,sha256=ykH6u-NLJx4Yesu_BLWa6T-vWIYFzJXJV1Txrbz3mPE,1177
47
47
  tensorneko/layer/attention.py,sha256=sNC6gZgZaeHqUAfzrPh0Vefp7T6nc1gFvrreuOL7wUg,6360
@@ -64,7 +64,7 @@ tensorneko/module/mlp.py,sha256=AFN6xmvlrNWOflLqVl-zVkoOJRZpYxYB4bnI10JG5CU,3361
64
64
  tensorneko/module/residual.py,sha256=S59TqiiD_310HQ3a6s3r49XY_7Dc4RGxONQtSvzEfN0,2958
65
65
  tensorneko/module/transformer.py,sha256=h4NvH3zGa0rZt0bv6e8VM31SimbQKRcocSR42zJYVoY,7602
66
66
  tensorneko/msg/__init__.py,sha256=GHrHjzw__0DcPBHBN6GzHrD8PD_7CwOZPRullOaZyW0,71
67
- tensorneko/notebook/__init__.py,sha256=Pgz4aTJg5_3zTzBIaML4LHAMOFgrBOHw3o00ZuvNuQU,171
67
+ tensorneko/notebook/__init__.py,sha256=4cCi3ZyaX48hLDvJQqW0G3a4z_vdzmh_jtJ-Jzil4SM,197
68
68
  tensorneko/optim/__init__.py,sha256=89XjYQICij8SkrW5iryfZgmbxcTDxA3hhVZTgR4588o,33
69
69
  tensorneko/optim/lr_scheduler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
70
70
  tensorneko/preprocess/__init__.py,sha256=0Z0eA3_I2wphyyZlzZYRrx2muWTF0QMFq2Y-jh8oVKU,808
@@ -73,7 +73,7 @@ tensorneko/preprocess/enum.py,sha256=Wp5qFaUjea5XU4o3N0WxUd-qfzI-m5vr4ZWSqWjELb4
73
73
  tensorneko/preprocess/pad.py,sha256=b4IbbhGNRotZ7weZcKA7hfDqSixPo5KjM6khnqzaeUA,3238
74
74
  tensorneko/preprocess/resize.py,sha256=hitMlzVnN6n_8nEJwxy4C4ErZrTwpM86QGnYewsrmf8,3469
75
75
  tensorneko/preprocess/face_detector/__init__.py,sha256=_ktIfUZqGTX0hk7RBgKf-zHwG2n9KRH4RS7rjuOI8Bo,262
76
- tensorneko/util/__init__.py,sha256=1ygIyTqZm1ROCsgQJ9_df5TOezyHx08stMDGjtyMxhE,2208
76
+ tensorneko/util/__init__.py,sha256=39G34a2k5ktVtBAh4N4RMsePEak5rzPDhbNNdXo-Ye4,2258
77
77
  tensorneko/util/configuration.py,sha256=xXeAjDh1FCNTmSPwDdkL-uH-ULfzFF6Fg0LT7gsZ6nQ,2510
78
78
  tensorneko/util/dispatched_misc.py,sha256=_0Go7XezdYB7bpMnCs1MDD_6mPNoWP5qt8DoKuPxynI,997
79
79
  tensorneko/util/gc.py,sha256=P3bOZ-2VUNyswnfVz5xfj__ecTSAHpu_kLp2wFcpb6M,185
@@ -86,8 +86,8 @@ tensorneko/visualization/log_graph.py,sha256=NvOwWVc_petXWYdgaHosPFLa43sHBeacbYc
86
86
  tensorneko/visualization/matplotlib.py,sha256=xs9Ssc44ojZX65QU8-fftA7Ug_pBuZ3TBtM8vETNq9w,1568
87
87
  tensorneko/visualization/image_browser/__init__.py,sha256=AtykhAE3bXQS6SOWbeYFeeUE9ts9XOFMvrL31z0LoMg,63
88
88
  tensorneko/visualization/watcher/__init__.py,sha256=Nq752qIYvfRUZ8VctKQRSqhxh5KmFbWcqPfZlijVx6s,379
89
- tensorneko-0.3.17.dist-info/LICENSE,sha256=Vd75kwgJpVuMnCRBWasQzceMlXt4YQL13ikBLy8G5h0,1067
90
- tensorneko-0.3.17.dist-info/METADATA,sha256=cezBL0TH1cdY0Sct7GCGLK9jW0ckUXZIL7PRScTckBg,19998
91
- tensorneko-0.3.17.dist-info/WHEEL,sha256=g4nMs7d-Xl9-xC9XovUrsDHGXt-FT0E17Yqo92DEfvY,92
92
- tensorneko-0.3.17.dist-info/top_level.txt,sha256=sZHwlP0iyk7_zHuhRHzSBkdY9yEgyC48f6UVuZ6CvqE,11
93
- tensorneko-0.3.17.dist-info/RECORD,,
89
+ tensorneko-0.3.19.dist-info/LICENSE,sha256=Vd75kwgJpVuMnCRBWasQzceMlXt4YQL13ikBLy8G5h0,1067
90
+ tensorneko-0.3.19.dist-info/METADATA,sha256=9M81K0ENnhym0S-qyKQTlQzSADnpOZIszNWJ66q4IuM,19998
91
+ tensorneko-0.3.19.dist-info/WHEEL,sha256=g4nMs7d-Xl9-xC9XovUrsDHGXt-FT0E17Yqo92DEfvY,92
92
+ tensorneko-0.3.19.dist-info/top_level.txt,sha256=sZHwlP0iyk7_zHuhRHzSBkdY9yEgyC48f6UVuZ6CvqE,11
93
+ tensorneko-0.3.19.dist-info/RECORD,,