tencheck 0.0.2__tar.gz

tencheck-0.0.2/LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2024-present Justin Yan
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,49 @@
+ Metadata-Version: 2.1
+ Name: tencheck
+ Version: 0.0.2
+ Summary: A library for pytorch layer testing.
+ Author-email: Justin Yan <justin@iomorphic.com>
+ Project-URL: Homepage, https://github.com/justin-yan/tencheck
+ Classifier: Intended Audience :: Developers
+ Classifier: Operating System :: OS Independent
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch<2.4,>2
+ Requires-Dist: jaxtyping
+ Requires-Dist: numpy
+
+ # tencheck
+
+ `tencheck` provides a simple set of utilities for analyzing and validating the properties of layers of deep neural nets.
+
+ It's typically quite difficult to validate that a layer "behaves properly" (oftentimes, the final barometer is simply how well a model performs with the layer included), and many brittle unit tests have been written involving randomly instantiated tensors and `is_close` checks. We believe a good "first line of defense" for neural nets is to create a suite of properties that can be asserted about a layer, while requiring minimal effort per layer in order to do so.
+
+ We think there are two aspects of property-based testing that are quite useful to take inspiration from:
+
+ - Automatically generating inputs (and generating inputs of variable sizes and values to elucidate properties of interest).
+ - Evaluating properties based on the maintenance of invariants instead of attempting to exactly match values (which is particularly difficult to interpret in deep neural nets).
+
+ However, an important difference is that the properties of interest are generally fairly generic and often shared between layers, while the input generation strategies are pretty similar (they're all tensors). So the focus of `tencheck` is to provide:
+
+ - An (attempted) universal input generation harness.
+ - A variety of interesting properties.
+ - Three modalities: assertion, analysis, and profiling.
+
+ The following requirements need to be met for `tencheck` to work:
+
+ - Your layers are implemented in torch.
+ - The `.forward()` method is annotated with [jaxtyping](https://github.com/patrick-kidger/jaxtyping).
+
+
+ ## Backlog
+
+ - For profiling, use a grid of input sizes to generate performance curves.
+ - Pick a flop counter and use for profiling
+ - Tensor container types include more options like dataclasses.
+ - Auto-generate simple hyperparameters for layer instantiation.
+ - Refine dtype mapping and coherence.
@@ -0,0 +1,30 @@
+ # tencheck
+
+ `tencheck` provides a simple set of utilities for analyzing and validating the properties of layers of deep neural nets.
+
+ It's typically quite difficult to validate that a layer "behaves properly" (oftentimes, the final barometer is simply how well a model performs with the layer included), and many brittle unit tests have been written involving randomly instantiated tensors and `is_close` checks. We believe a good "first line of defense" for neural nets is to create a suite of properties that can be asserted about a layer, while requiring minimal effort per layer in order to do so.
+
+ We think there are two aspects of property-based testing that are quite useful to take inspiration from:
+
+ - Automatically generating inputs (and generating inputs of variable sizes and values to elucidate properties of interest).
+ - Evaluating properties based on the maintenance of invariants instead of attempting to exactly match values (which is particularly difficult to interpret in deep neural nets).
+
+ However, an important difference is that the properties of interest are generally fairly generic and often shared between layers, while the input generation strategies are pretty similar (they're all tensors). So the focus of `tencheck` is to provide:
+
+ - An (attempted) universal input generation harness.
+ - A variety of interesting properties.
+ - Three modalities: assertion, analysis, and profiling.
+
+ The following requirements need to be met for `tencheck` to work:
+
+ - Your layers are implemented in torch.
+ - The `.forward()` method is annotated with [jaxtyping](https://github.com/patrick-kidger/jaxtyping).
+
+
+ ## Backlog
+
+ - For profiling, use a grid of input sizes to generate performance curves.
+ - Pick a flop counter and use for profiling
+ - Tensor container types include more options like dataclasses.
+ - Auto-generate simple hyperparameters for layer instantiation.
+ - Refine dtype mapping and coherence.
@@ -0,0 +1,90 @@
+ [build-system]
+ requires = ["setuptools"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "tencheck"
+ version = "0.0.2"
+ authors = [
+     { name="Justin Yan", email="justin@iomorphic.com" }
+ ]
+ description = "A library for pytorch layer testing."
+ readme = "README.md"
+ requires-python = ">=3.11"
+ classifiers = [
+     "Intended Audience :: Developers",
+     "Operating System :: OS Independent",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.11",
+ ]
+ dependencies = [
+     ######
+     ### Custom Dependencies Section Begin
+     ######
+     "torch>2,<2.4",
+     "jaxtyping",
+     "numpy",
+     ######
+     ### Custom Dependencies Section End
+     ######
+ ]
+
+ [project.urls]
+ "Homepage" = "https://github.com/justin-yan/tencheck"
+
+ [tool.setuptools]
+ zip-safe = false
+ include-package-data = true
+
+ [tool.setuptools.package-data]
+ "tencheck" = ["py.typed"]
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ #######
+ ### Miscellaneous Tool Configuration
+ #######
+ [tool.ruff]
+ line-length = 150
+ target-version = "py311"
+
+ [tool.ruff.format]
+ quote-style = "double"
+
+ [tool.ruff.lint]
+ select = ["E", "F", "W", "I"]
+ ignore = ["F722"]
+
+ [tool.ruff.lint.isort]
+ known-first-party = ["tencheck"]
+
+ [tool.pytest.ini_options]
+ addopts = "-ra -q --doctest-modules --jaxtyping-packages=tencheck.examples,beartype.beartype"
+
+ [tool.mypy]
+ mypy_path = "src"
+ disallow_untyped_defs = true
+ disallow_any_unimported = true
+ allow_redefinition = false
+ ignore_errors = false
+ implicit_reexport = false
+ local_partial_types = true
+ no_implicit_optional = true
+ strict_equality = true
+ strict_optional = true
+ warn_no_return = true
+ warn_redundant_casts = true
+ warn_unreachable = true
+ warn_unused_configs = true
+ warn_unused_ignores = true
+
+ ######
+ ### Custom Directives Section Begin
+ ######
+
+ ######
+ ### Custom Directives Section End
+ ######
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
File without changes
@@ -0,0 +1,15 @@
+ from torch import nn as nn
+
+
+ def unused_params_check(layer: nn.Module) -> None:
+     """
+     This check is run after a backward pass is completed.
+
+     If any unused parameters are found, an exception is thrown with the named parameters.
+     """
+     unused_parameters = []
+     for name, param in layer.named_parameters():
+         if param.grad is None:
+             unused_parameters.append(name)
+
+     assert len(unused_parameters) == 0, f"Unused parameters: {unused_parameters} detected."
@@ -0,0 +1,67 @@
+ import torch
+ import torch.nn as nn
+ from jaxtyping import Float
+
+
+ class SimpleLinReluModule(nn.Module):
+     def __init__(self, out_features: int) -> None:
+         super(SimpleLinReluModule, self).__init__()
+         self.linear = nn.Linear(32, out_features)
+         self.relu = nn.ReLU()
+
+     def forward(self, x: Float[torch.Tensor, "B 32"]) -> Float[torch.Tensor, "B O"]:
+         x = self.linear(x)
+         x = self.relu(x)
+         return x
+
+
+ class CasedLinReluModule(nn.Module):
+     _tencheck_cases = [{"out_features": 10}, {"out_features": 20}]
+
+     def __init__(self, out_features: int) -> None:
+         super(CasedLinReluModule, self).__init__()
+         self.linear = nn.Linear(32, out_features)
+         self.relu = nn.ReLU()
+
+     def forward(self, x: Float[torch.Tensor, "B 32"]) -> Float[torch.Tensor, "B O"]:
+         x = self.linear(x)
+         x = self.relu(x)
+         return x
+
+
+ class UnusedParamsModule(nn.Module):
+     def __init__(self, out_features: int) -> None:
+         super(UnusedParamsModule, self).__init__()
+         self.linear = nn.Linear(32, out_features)
+         self.unused_linear = nn.Linear(32, out_features)
+         self.relu = nn.ReLU()
+
+     def forward(self, x: Float[torch.Tensor, "B 32"]) -> Float[torch.Tensor, "B O"]:
+         x = self.linear(x)
+         x = self.relu(x)
+         return x
+
+
+ class BrokenModule(nn.Module):
+     def __init__(self) -> None:
+         super(BrokenModule, self).__init__()
+         self.linear = nn.Linear(32, 4)
+         self.relu = nn.ReLU()
+
+     def forward(self, x: Float[torch.Tensor, "B 32"]) -> Float[torch.Tensor, "B 4"]:
+         x = self.linear(x)
+         raise Exception("Module is broken")
+         x = self.relu(x)  # type: ignore[unreachable]
+         return x
+
+
+ class MistypedModule(nn.Module):
+     def __init__(self, out_features: int) -> None:
+         super(MistypedModule, self).__init__()
+         self.linear = nn.Linear(32, out_features)
+         self.relu = nn.ReLU()
+
+     def forward(self, x: Float[torch.Tensor, "B 32"]) -> Float[torch.Tensor, "B 1025"]:
+         x = self.linear(x)
+         x = self.relu(x)
+         return x
@@ -0,0 +1,68 @@
+ import time
+ from typing import Optional, Type
+
+ import torch
+ from torch import nn as nn
+
+ from tencheck.checks import unused_params_check
+ from tencheck.input import input_gen
+ from tencheck.loss import trivial_loss
+ from tencheck.ttypes import CaseDefined, LayerStats
+
+
+ def check_layers(layers: list[nn.Module | Type[CaseDefined]], seed: Optional[int] = None) -> None:
+     """
+     This method receives a *concrete* list of layer objects, and asserts the relevant properties.
+     """
+     for layer in layers:
+         if isinstance(layer, CaseDefined):  # This works even though layer is a class, not an obj, due to @runtime_checkable
+             for case in layer._tencheck_cases:
+                 layer_obj = layer(**case)
+                 _single_layer_assert_all(layer_obj, seed)
+         else:
+             _single_layer_assert_all(layer, seed)  # type: ignore[arg-type]
+
+
+ def _single_layer_assert_all(layer: nn.Module, seed: Optional[int] = None) -> None:
+     in_tens = input_gen(layer, seed)
+     layer.zero_grad(set_to_none=True)
+
+     # throws TypeCheckError for shapecheck
+     # throws Exception for generic issues
+     out = layer.forward(**in_tens)
+     loss = trivial_loss(out)
+     loss.backward()
+
+     unused_params_check(layer)
+
+
+ def profile_layer(layer: nn.Module) -> LayerStats:
+     """
+     This runs a basic profiling setup for a single layer.
+     """
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         raise Exception("GPU is not available")
+
+     layer = layer.to(device)
+     layer.zero_grad(set_to_none=True)
+     in_tens = input_gen(layer, device=device)
+
+     torch.cuda.reset_peak_memory_stats()
+     torch.cuda.synchronize()  # CUDA kernels launch asynchronously; drain pending work before starting the clock
+
+     start = time.perf_counter()
+     out = layer.forward(**in_tens)
+     loss = trivial_loss(out)
+     loss.backward()
+     torch.cuda.synchronize()  # wait for the forward and backward kernels to finish before reading the clock
+     elapsed = time.perf_counter() - start
+     peak_memory_gbs = torch.cuda.max_memory_allocated() / 1024**3
+     gigaflops = 0.0  # placeholder: no flop counter has been chosen yet (see the README backlog)
+
+     return LayerStats(elapsed, peak_memory_gbs, gigaflops)
+
+
+ if __name__ == "__main__":
+     from tencheck.examples import SimpleLinReluModule
+
+     print(profile_layer(SimpleLinReluModule(5)))
@@ -0,0 +1,84 @@
+ import inspect
+ import logging
+ import random
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from jaxtyping._array_types import _FixedDim, _NamedDim, _NamedVariadicDim, _SymbolicDim
+ from torch import Tensor
+ from torch.testing import make_tensor
+
+ dtype_mapping = {
+     "bool": torch.bool,
+     "bool_": torch.bool,
+     "uint4": torch.int32,
+     "uint8": torch.int32,
+     "uint16": torch.int32,
+     "uint32": torch.int32,
+     "uint64": torch.int32,
+     "int4": torch.int32,
+     "int8": torch.int32,
+     "int16": torch.int32,
+     "int32": torch.int32,
+     "int64": torch.int32,
+     "bfloat16": torch.float32,
+     "float16": torch.float32,
+     "float32": torch.float32,
+     "float64": torch.float32,
+     "complex64": torch.float32,
+     "complex128": torch.float32,
+ }
+
+
+ @dataclass
+ class TensorSpec:
+     shape: tuple[int, ...]
+     dtype: torch.dtype
+
+
+ def input_gen(layer: nn.Module, seed: Optional[int] = None, device: str | torch.device = "cpu") -> dict[str, Tensor]:
+     """
+     For a given layer that is type annotated with jaxtyping, produce a map of mock tensors that can be used like so:
+
+         in_tens = input_gen(layer)
+         layer.forward(**in_tens)
+     """
+     if seed is not None:  # a bare `if seed:` would silently ignore seed=0
+         torch.manual_seed(seed)
+         random.seed(seed)
+
+     signature = inspect.signature(layer.forward)
+     tensor_specs: dict[str, TensorSpec] = {}  # parameter to shape
+     # Across all of the parameters, we'll have a mix of `_NamedDim`, `_FixedDim`, `_NamedVariadicDim`, or `_SymbolicDim`
+     # We want to pick concrete dimensions for everything that isn't fixed,
+     # and then we want to generate tensors for everything.
+     dimension_name_map: dict[str, int] = {}
+     for name, param_obj in signature.parameters.items():
+         shape: list[int] = []
+         for dim in param_obj.annotation.dims:
+             match dim:
+                 case _NamedDim(nm, _, _):
+                     sz = dimension_name_map.setdefault(nm, random.randint(4, 64))
+                     shape.append(sz)
+                 case _FixedDim(sz, _):
+                     shape.append(sz)
+                 case _NamedVariadicDim() | _SymbolicDim() | _:
+                     raise NotImplementedError("Don't yet handle these dimension cases")
+         dt = dtype_mapping[param_obj.annotation.dtypes[0]]
+         tensor_specs[name] = TensorSpec(tuple(shape), dt)
+
+     output = {}
+     for name, spec in tensor_specs.items():
+         match spec.dtype:
+             case torch.float32:
+                 mock_ten = make_tensor(spec.shape, dtype=spec.dtype, device=device, low=-1, high=1)
+             case torch.int32:
+                 mock_ten = make_tensor(spec.shape, dtype=spec.dtype, device=device, low=-10, high=10)
+             case _:
+                 logging.error(spec)
+                 raise NotImplementedError("Don't yet handle these dtypes")
+         output[name] = mock_ten
+
+     return output
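The shape-resolution step in `input_gen` above can be illustrated without torch or jaxtyping. The following is a minimal sketch under stated assumptions: `resolve_shapes` is a hypothetical helper (not part of tencheck) that parses plain shape strings such as `"B 32"` instead of jaxtyping's dimension objects. It shows the two rules the loop applies: integer tokens are fixed sizes, and each named dimension is drawn once and reused wherever the same name appears.

```python
import random
from typing import Dict, List, Optional, Tuple


def resolve_shapes(annotations: Dict[str, str], seed: Optional[int] = None) -> Dict[str, Tuple[int, ...]]:
    """Turn shape strings such as "B 32" into concrete shapes (hypothetical helper)."""
    # random.Random(None) seeds from OS entropy, while 0 is a perfectly valid seed,
    # so no truthiness test on `seed` is needed here.
    rng = random.Random(seed)
    named_sizes: Dict[str, int] = {}
    shapes: Dict[str, Tuple[int, ...]] = {}
    for param, spec in annotations.items():
        dims: List[int] = []
        for token in spec.split():
            if token.isdigit():
                # Fixed dimension: use the literal size.
                dims.append(int(token))
            else:
                # Named dimension: draw a size the first time, reuse it afterwards.
                if token not in named_sizes:
                    named_sizes[token] = rng.randint(4, 64)
                dims.append(named_sizes[token])
        shapes[param] = tuple(dims)
    return shapes


shapes = resolve_shapes({"x": "B 32", "mask": "B T"}, seed=0)
print(shapes)
```

Both parameters share the same size for `B`, which is what lets a layer with several tensor arguments receive mutually consistent inputs. Unlike the `setdefault(nm, random.randint(4, 64))` form, this sketch only draws a random number when a name is new, so the sizes assigned to later names do not depend on how often earlier names repeat.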
@@ -0,0 +1,34 @@
+ import torch
+ from torch import Tensor
+
+ TensorContainerTypes = Tensor | list | set | tuple | dict
+
+
+ def flatten_tensors(tct: TensorContainerTypes) -> list[Tensor]:
+     """
+     This method will recursively traverse any potential container structures and extract any tensors as a flat list.
+     """
+     tensor_list: list[Tensor] = []
+     match tct:
+         case Tensor():
+             tensor_list.append(tct)
+         case list():
+             for e in tct:
+                 match e:
+                     case Tensor():
+                         tensor_list.append(e)
+                     case _:
+                         tensor_list.extend(flatten_tensors(e))
+         case set() | tuple():
+             tensor_list.extend(flatten_tensors(list(tct)))
+         case dict():
+             tensor_list.extend(flatten_tensors(list(tct.values())))
+         case _:
+             raise NotImplementedError("Unexpected input format")
+
+     return tensor_list
+
+
+ def trivial_loss(tct: TensorContainerTypes) -> Tensor:
+     partial_sums = [t.sum() for t in flatten_tensors(tct)]
+     return torch.stack(partial_sums).sum()
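The recursive flattening behind `trivial_loss` can be sketched in plain Python. This is an illustration only, assuming floats as stand-ins for tensors; `flatten_leaves` and `scalar_loss` are hypothetical names, and in the real module `Tensor.sum()` and `torch.stack` take the place of the plain `sum`.

```python
from typing import Any, List


def flatten_leaves(obj: Any) -> List[float]:
    """Recursively collect numeric leaves from nested lists, tuples, sets, and dicts."""
    if isinstance(obj, (int, float)):
        # A leaf: the stand-in for a tensor.
        return [float(obj)]
    if isinstance(obj, dict):
        # Only the values can hold leaves; keys are ignored.
        obj = list(obj.values())
    if isinstance(obj, (list, tuple, set)):
        leaves: List[float] = []
        for element in obj:
            leaves.extend(flatten_leaves(element))
        return leaves
    raise NotImplementedError(f"Unexpected input format: {type(obj).__name__}")


def scalar_loss(obj: Any) -> float:
    # Summing every leaf reduces an arbitrarily nested output to a single scalar.
    return sum(flatten_leaves(obj))


nested = {"a": [1.0, (2.0, 3.0)], "b": 4.0}
print(flatten_leaves(nested), scalar_loss(nested))
```

Summing is a deliberate choice for this kind of check: the derivative of a sum with respect to each of its terms is 1, so with real tensors every output that was flattened contributes to the backward pass, and any parameter that took part in producing an output ends up with a non-`None` gradient.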
File without changes
@@ -0,0 +1,16 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Protocol, runtime_checkable
+
+
+ @runtime_checkable
+ class CaseDefined(Protocol):
+     @property
+     def _tencheck_cases(self) -> List[Dict[str, Any]]:
+         raise NotImplementedError()
+
+
+ @dataclass
+ class LayerStats:
+     total_time: float
+     peak_mem_gigs: float
+     giga_flop_count: float
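A note on the `CaseDefined` protocol above: because it is `@runtime_checkable`, `isinstance` only verifies that a `_tencheck_cases` attribute is present. A class object that carries the attribute passes, and so does an instance of that class. The sketch below reproduces this with plain classes standing in for `nn.Module` subclasses. `Cased`, `Plain`, and `expand_cases` are hypothetical names, not part of tencheck, and the extra `inspect.isclass` guard is a suggestion rather than what the harness does.

```python
import inspect
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class CaseDefined(Protocol):
    @property
    def _tencheck_cases(self) -> List[Dict[str, Any]]:
        raise NotImplementedError()


class Cased:
    # Hypothetical stand-in for a layer class that declares its own cases.
    _tencheck_cases = [{"out_features": 10}, {"out_features": 20}]

    def __init__(self, out_features: int) -> None:
        self.out_features = out_features


class Plain:
    # Hypothetical stand-in for a layer class without declared cases.
    def __init__(self, out_features: int) -> None:
        self.out_features = out_features


def expand_cases(items: List[Any]) -> List[Any]:
    """Instantiate case-defining classes once per case; pass other objects through."""
    expanded: List[Any] = []
    for item in items:
        # isinstance alone is also True for *instances* of Cased, so the class
        # check is what separates "instantiate me" from "already instantiated".
        if inspect.isclass(item) and isinstance(item, CaseDefined):
            expanded.extend(item(**case) for case in item._tencheck_cases)
        else:
            expanded.append(item)
    return expanded


objs = expand_cases([Cased, Plain(3), Cased(7)])
print([(type(o).__name__, o.out_features) for o in objs])
```

Without the `inspect.isclass` guard, an already-instantiated `Cased(7)` would take the case-expansion branch and be called with `out_features=...` as if it were a constructor.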
@@ -0,0 +1,49 @@
+ Metadata-Version: 2.1
+ Name: tencheck
+ Version: 0.0.2
+ Summary: A library for pytorch layer testing.
+ Author-email: Justin Yan <justin@iomorphic.com>
+ Project-URL: Homepage, https://github.com/justin-yan/tencheck
+ Classifier: Intended Audience :: Developers
+ Classifier: Operating System :: OS Independent
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch<2.4,>2
+ Requires-Dist: jaxtyping
+ Requires-Dist: numpy
+
+ # tencheck
+
+ `tencheck` provides a simple set of utilities for analyzing and validating the properties of layers of deep neural nets.
+
+ It's typically quite difficult to validate that a layer "behaves properly" (oftentimes, the final barometer is simply how well a model performs with the layer included), and many brittle unit tests have been written involving randomly instantiated tensors and `is_close` checks. We believe a good "first line of defense" for neural nets is to create a suite of properties that can be asserted about a layer, while requiring minimal effort per layer in order to do so.
+
+ We think there are two aspects of property-based testing that are quite useful to take inspiration from:
+
+ - Automatically generating inputs (and generating inputs of variable sizes and values to elucidate properties of interest).
+ - Evaluating properties based on the maintenance of invariants instead of attempting to exactly match values (which is particularly difficult to interpret in deep neural nets).
+
+ However, an important difference is that the properties of interest are generally fairly generic and often shared between layers, while the input generation strategies are pretty similar (they're all tensors). So the focus of `tencheck` is to provide:
+
+ - An (attempted) universal input generation harness.
+ - A variety of interesting properties.
+ - Three modalities: assertion, analysis, and profiling.
+
+ The following requirements need to be met for `tencheck` to work:
+
+ - Your layers are implemented in torch.
+ - The `.forward()` method is annotated with [jaxtyping](https://github.com/patrick-kidger/jaxtyping).
+
+
+ ## Backlog
+
+ - For profiling, use a grid of input sizes to generate performance curves.
+ - Pick a flop counter and use for profiling
+ - Tensor container types include more options like dataclasses.
+ - Auto-generate simple hyperparameters for layer instantiation.
+ - Refine dtype mapping and coherence.
@@ -0,0 +1,17 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ src/tencheck/__init__.py
+ src/tencheck/checks.py
+ src/tencheck/examples.py
+ src/tencheck/harness.py
+ src/tencheck/input.py
+ src/tencheck/loss.py
+ src/tencheck/py.typed
+ src/tencheck/ttypes.py
+ src/tencheck.egg-info/PKG-INFO
+ src/tencheck.egg-info/SOURCES.txt
+ src/tencheck.egg-info/dependency_links.txt
+ src/tencheck.egg-info/not-zip-safe
+ src/tencheck.egg-info/requires.txt
+ src/tencheck.egg-info/top_level.txt
@@ -0,0 +1,3 @@
+ torch<2.4,>2
+ jaxtyping
+ numpy
@@ -0,0 +1 @@
+ tencheck