potnn 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
potnn/config.py ADDED
@@ -0,0 +1,112 @@
1
+ """Configuration module for potnn."""
2
+
3
+ from typing import Optional, Union, List, Dict
4
+
5
+
6
# Supported encoding names. Config checks default_encoding and every
# value in layer_encodings against this set.
VALID_ENCODINGS = {
    'unroll',
    'fp130',
    '5level',
    '2bit',
    'ternary',
}
8
+
9
+
10
class Config:
    """Configuration for potnn model compilation.

    Args:
        flash: Flash memory budget in bytes (e.g., 16384 for 16KB).
        ram: RAM memory budget in bytes (e.g., 2048 for 2KB).
        input_norm: Input normalization method: 255, 256, 'standardize', or None.
        mean: Mean for standardization. Either a scalar (1-channel) or a
            sequence with one value per input channel.
        std: Standard deviation for standardization. Either a scalar or a
            sequence with one value per input channel; every value must be
            strictly positive.
        input_h: Input height (default 16, e.g. 16x16 MNIST).
        input_w: Input width (default 16, e.g. 16x16 MNIST).
        input_channels: Number of input channels (1 for grayscale, 3 for RGB).
        layer_encodings: Dict mapping layer names to encoding types.
            Example: {'conv1': 'unroll', 'fc': '5level'}
            Valid encodings: 'unroll', 'fp130', '5level', '2bit', 'ternary'.
            The mapping is copied, so later changes made by the caller
            cannot bypass validation.
        default_encoding: Encoding for layers not present in layer_encodings.
            Default is 'unroll' for backward compatibility.

    Raises:
        ValueError: If input_norm is not an accepted value, if
            input_norm='standardize' without both mean and std, if mean/std
            length differs from input_channels, if any std value is not
            positive, or if any encoding name is not in VALID_ENCODINGS.
    """

    def __init__(
        self,
        flash: int,
        ram: int,
        input_norm: Optional[Union[int, str]] = 256,
        mean: Optional[Union[float, List[float]]] = None,
        std: Optional[Union[float, List[float]]] = None,
        input_h: int = 16,
        input_w: int = 16,
        input_channels: int = 1,
        layer_encodings: Optional[Dict[str, str]] = None,
        default_encoding: str = 'unroll',
    ):
        self.flash = flash
        self.ram = ram
        self.input_norm = input_norm
        self.input_h = input_h
        self.input_w = input_w
        self.input_channels = input_channels

        # Normalize mean/std to a list of floats for consistency:
        #   1-channel: mean=0.5 -> [0.5]
        #   3-channel: mean=[0.4914, 0.4822, 0.4465] -> same values as floats
        self.mean = self._as_float_list(mean)
        self.std = self._as_float_list(std)

        # Validate input_norm
        if input_norm not in [255, 256, 'standardize', None]:
            raise ValueError(f"input_norm must be 255, 256, 'standardize', or None, got {input_norm}")

        if input_norm == 'standardize' and (self.mean is None or self.std is None):
            raise ValueError("mean and std must be provided when input_norm='standardize'")

        # Validate mean/std length matches input_channels
        if self.mean is not None and len(self.mean) != input_channels:
            raise ValueError(f"mean length ({len(self.mean)}) must match input_channels ({input_channels})")
        if self.std is not None and len(self.std) != input_channels:
            raise ValueError(f"std length ({len(self.std)}) must match input_channels ({input_channels})")

        # Standardization divides by std, so zero/negative/NaN values are
        # rejected here. `not s > 0` is used (rather than `s <= 0`) so that
        # NaN is caught as well.
        if self.std is not None and any(not s > 0 for s in self.std):
            raise ValueError(f"std values must be positive, got {self.std}")

        # Standardization is enabled whenever both statistics are available.
        self.use_standardization = self.mean is not None and self.std is not None

        # Validate and store encoding settings
        if default_encoding not in VALID_ENCODINGS:
            raise ValueError(
                f"Invalid default_encoding '{default_encoding}'. "
                f"Valid options: {VALID_ENCODINGS}"
            )
        self.default_encoding = default_encoding

        # Copy so that mutations of the caller's dict after construction
        # cannot introduce unvalidated encodings.
        self.layer_encodings = dict(layer_encodings) if layer_encodings else {}
        for layer_name, encoding in self.layer_encodings.items():
            if encoding not in VALID_ENCODINGS:
                raise ValueError(
                    f"Invalid encoding '{encoding}' for layer '{layer_name}'. "
                    f"Valid options: {VALID_ENCODINGS}"
                )

    @staticmethod
    def _as_float_list(value) -> Optional[List[float]]:
        """Convert a scalar or an iterable of numbers to a list of floats.

        None is passed through unchanged. A scalar becomes a one-element list.
        """
        if value is None:
            return None
        if isinstance(value, (int, float)):
            return [float(value)]
        try:
            items = iter(value)
        except TypeError:
            # Numeric scalars that are neither int nor float subclasses
            # (e.g. numpy.float32) are not iterable; treat them as a single
            # channel instead of crashing.
            return [float(value)]
        return [float(v) for v in items]

    def get_encoding(self, layer_name: str) -> str:
        """Get encoding for a specific layer.

        Args:
            layer_name: Name of the layer (e.g., 'conv1', 'features.0')

        Returns:
            The encoding from layer_encodings if the layer is listed there,
            otherwise default_encoding.
        """
        return self.layer_encodings.get(layer_name, self.default_encoding)