cracknuts-squirrel 0.0.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
# Package version string; PEP 440 normalizes "0.0.1-beta.0" to 0.0.1b0,
# the version used in the wheel filename.
__version__ = "0.0.1-beta.0"
@@ -0,0 +1,126 @@
1
+ # Copyright 2024 CrackNuts. All rights reserved.
2
+
3
+ import numpy as np
4
+ import zarr
5
+ from cracknuts_squirrel.preprocessing_basic import PPBasic
6
+ import matplotlib.pyplot as plt
7
+
8
# Hamming weights of the values 0-255 used for model values.
# Entry v is the number of set bits in byte value v, stored as float32 so the
# looked-up model can feed straight into floating-point correlation math.
WEIGHTS = np.array([bin(byte_value).count("1") for byte_value in range(256)], np.float32)
26
+
27
class CorrelationAnalysis(PPBasic):
    """
    Generic correlation-coefficient analysis built on the PPBasic preprocessing base class.

    ``perform_analysis`` correlates the Hamming weight of the first 16 plaintext
    bytes with every selected sample point of the traces.
    """

    def __init__(self, input_path=None, output_path=None, sample_range=(0, None), **kwargs):
        super().__init__(input_path=input_path, output_path=output_path, **kwargs)
        # (start, stop) slice bounds applied to the sample axis of the traces.
        self.sample_range = sample_range

    def calculate_correlation(self, traces, plaintext_bytes):
        """
        Compute the Pearson correlation between each model column and each sample point.

        Args:
            traces: array-like of shape (n_traces, n_samples) with the power traces.
            plaintext_bytes: array-like of shape (n_traces, n_bytes) with the model
                value of every trace (e.g. Hamming weights); a 1-D array is treated
                as a single column.

        Returns:
            float64 array of shape (n_bytes, n_samples). Entries where the model
            column or the sample column has zero variance are NaN, which is what
            ``np.corrcoef`` yields for such inputs.

        Raises:
            ValueError: if ``traces`` is not 2-D or the trace counts differ.
        """
        traces = np.asarray(traces)
        plaintext_bytes = np.asarray(plaintext_bytes)
        if plaintext_bytes.ndim == 1:
            plaintext_bytes = plaintext_bytes[:, np.newaxis]

        if traces.ndim != 2:
            raise ValueError(f"轨迹必须是二维数组,当前形状为 {traces.shape}")
        if traces.shape[0] != plaintext_bytes.shape[0]:
            raise ValueError("轨迹数量与明文数量不匹配")

        return self._pearson_columns(plaintext_bytes, traces)

    @staticmethod
    def _pearson_columns(model, traces, chunk_size=4096):
        """Vectorized Pearson correlation of every ``model`` column with every ``traces`` column."""
        model = np.asarray(model, dtype=np.float64)
        model_centered = model - model.mean(axis=0)
        model_norm = np.sqrt(np.einsum("ij,ij->j", model_centered, model_centered))

        n_samples = traces.shape[1]
        correlations = np.empty((model.shape[1], n_samples), dtype=np.float64)
        # Convert the traces to float64 one slice of the sample axis at a time so the
        # extra memory stays bounded for long traces.
        for lo in range(0, n_samples, chunk_size):
            hi = lo + chunk_size
            block = np.asarray(traces[:, lo:hi], dtype=np.float64)
            block_centered = block - block.mean(axis=0)
            block_norm = np.sqrt(np.einsum("ij,ij->j", block_centered, block_centered))
            numer = model_centered.T @ block_centered
            denom = np.outer(model_norm, block_norm)
            # Zero variance on either side leaves NaN instead of dividing by zero.
            correlations[:, lo:hi] = np.divide(
                numer, denom, out=np.full_like(numer, np.nan), where=denom > 0
            )

        # Rounding can push |r| marginally past 1; clip like np.corrcoef does.
        np.clip(correlations, -1.0, 1.0, out=correlations)
        return correlations

    def perform_analysis(self):
        """
        Run the correlation analysis with the Hamming weight of plaintext bytes as model.

        Uses the first ``sel_num_traces`` traces restricted to ``sample_range`` and the
        first 16 plaintext bytes of the same traces. The result is written to
        '/0/0/correlation' in ``self.output_path`` together with metadata.

        Returns:
            float64 array of shape (n_bytes, n_samples), with n_bytes <= 16.
        """
        n_traces = self.sel_num_traces
        # Slice traces and plaintexts with the same trace count so their rows line up.
        processed_traces = self.t[:n_traces, self.sample_range[0]:self.sample_range[1]]
        plaintext = np.asarray(self.plaintext[:n_traces, :16])
        hw_matrix = WEIGHTS[plaintext]

        correlation_matrix = self.calculate_correlation(
            traces=processed_traces,
            plaintext_bytes=hw_matrix,
        )

        # Open (and overwrite) the output store only once the result exists, so a
        # failed analysis does not wipe earlier output.
        store = zarr.DirectoryStore(self.output_path)
        root = zarr.group(store=store, overwrite=True)
        root.create_dataset(
            '/0/0/correlation',
            data=correlation_matrix,
            chunks=(16, 1000)
        )

        root.attrs.update({
            "analysis_metadata": {
                "sample_range": self.sample_range,
                "trace_count": self.sel_num_traces,
                "model_type": "hamming_weight"
            }
        })
        return correlation_matrix
111
+
112
if __name__ == "__main__":
    # Example usage: correlate plaintext Hamming weights with a captured dataset.
    analyzer = CorrelationAnalysis(input_path='E:\\codes\\Acquisition\\dataset\\20250722204543.zarr')
    analyzer.auto_out_filename()
    # analyzer.set_range(sample_range=(500, 10000))

    result = analyzer.perform_analysis()

    print(result.shape)
    # One curve per analysed byte; the count follows the data rather than assuming 16.
    for byte_curve in result:
        plt.plot(byte_curve)

    # Report before plt.show(), which blocks until the figure window is closed.
    print(f"分析完成,最大相关系数:{np.nanmax(np.abs(result)):.4f}")
    plt.show()
@@ -0,0 +1,201 @@
1
+ # Copyright 2024 CrackNuts. All rights reserved.
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import zarr
8
+
9
+ from cracknuts_squirrel.preprocessing_basic import PPBasic
10
+
11
@dataclass
class AnalysisParams:
    """
    One model-data selection for ``CorrelationAnalysis.perform_analysis``.

    Attributes:
        data_type: per-trace data attribute to read: "plaintext", "ciphertext",
            "key" or "extended".
        data_width: width of one model word in bytes (used as ``data_width * 8`` bits).
        start: index of the first byte column to use.
        count: number of byte columns to use; ``None`` runs through the last column.
    """

    data_type: str = 'plaintext'  # which data array to model
    data_width: int = 1           # word width in bytes
    start: int = 0                # first column
    count: Optional[int] = None   # column count, None = to the end
17
+
18
class CorrelationAnalysis(PPBasic):
    """
    Generic correlation-coefficient analysis built on the PPBasic preprocessing base class.

    Correlates the Hamming weight of selected plaintext / ciphertext / key / extended
    bytes with every selected sample point of the traces.
    """

    # Hamming weights of the values 0-255 used for model values (index = byte value).
    _HW_TABLE = np.array([bin(value).count("1") for value in range(256)], np.uint8)

    # Per-trace data attributes that single_perform_analysis can model.
    _DATA_TYPES = ("plaintext", "ciphertext", "key", "extended")

    def __init__(self, input_path=None, output_path=None, sample_range=(0, None), **kwargs):
        super().__init__(input_path=input_path, output_path=output_path, **kwargs)
        # (start, stop) slice bounds applied to the sample axis of the traces.
        self.sample_range = sample_range

    @classmethod
    def hamming_weight(cls, data: np.ndarray, bit_width: int) -> np.ndarray:
        """
        Return the Hamming weight of byte data regrouped into ``bit_width``-bit words.

        Args:
            data: 2-D uint8 array of shape (n, m) (anything ``np.asarray`` accepts).
            bit_width: 8 for one weight per byte; 16, 32 or 64 to merge every 2, 4
                or 8 adjacent bytes into one word; 1, 2 or 4 to split every byte into
                8, 4 or 2 chunks, most-significant chunk first.

        Returns:
            Array of shape (n, m) for 8 bits, (n, m // (bit_width // 8)) when merging,
            or (n, m * (8 // bit_width)) when splitting.

        Raises:
            ValueError: if ``data`` is not a 2-D uint8 array, ``bit_width`` is not
                supported, or m is not a multiple of the bytes per merged word.
        """
        data = np.asarray(data)
        if data.ndim != 2:
            raise ValueError(f"输入必须是二维数组,当前形状为 {data.shape}")
        if data.dtype != np.uint8:
            raise ValueError("输入数组必须是 uint8 类型")

        if bit_width == 8:
            return cls._HW_TABLE[data]

        n, m = data.shape
        if bit_width > 8:
            if bit_width not in (16, 32, 64):
                raise ValueError("合并仅支持 16, 32, 64 位")
            num_bytes = bit_width // 8
            if m % num_bytes != 0:
                raise ValueError(f"列数 {m} 不能被 {num_bytes} 整除")
            per_byte = cls._HW_TABLE[data.reshape(n, m // num_bytes, num_bytes)]
            # The weight of a word is the sum of the weights of its bytes.
            return per_byte.astype(f"uint{bit_width}").sum(axis=2)

        if bit_width not in (1, 2, 4):
            raise ValueError("拆分仅支持 1, 2, 4 位")
        splits_per_byte = 8 // bit_width
        mask = (1 << bit_width) - 1
        # Right-shift amounts per chunk, most-significant chunk first (e.g. [4, 0]).
        shifts = np.arange(8 - bit_width, -1, -bit_width, dtype=np.uint8)
        chunks = (data[:, :, np.newaxis] >> shifts) & mask
        return cls._HW_TABLE[chunks.reshape(n, m * splits_per_byte)]

    @staticmethod
    def calculate_correlation(traces, data_bytes):
        """
        Compute the Pearson correlation between each model column and each sample point.

        Args:
            traces: array-like of shape (n_traces, n_samples) with the power traces.
            data_bytes: array-like of shape (n_traces, n_bytes) with the model value of
                every trace (e.g. Hamming weights); a 1-D array is a single column.

        Returns:
            float64 array of shape (n_bytes, n_samples). Rows whose model column has
            zero variance are 0; other entries whose sample column has zero variance
            are NaN, matching ``np.corrcoef``.

        Raises:
            ValueError: if ``traces`` is not 2-D or the trace counts differ.
        """
        traces = np.asarray(traces)
        data_bytes = np.asarray(data_bytes)
        if data_bytes.ndim == 1:
            data_bytes = data_bytes[:, np.newaxis]

        if traces.ndim != 2:
            raise ValueError(f"轨迹必须是二维数组,当前形状为 {traces.shape}")
        if traces.shape[0] != data_bytes.shape[0]:
            raise ValueError("轨迹数量与明文数量不匹配")

        chunk_size = 4096
        model = data_bytes.astype(np.float64)
        model_centered = model - model.mean(axis=0)
        model_norm = np.sqrt(np.einsum("ij,ij->j", model_centered, model_centered))

        n_samples = traces.shape[1]
        correlations = np.empty((model.shape[1], n_samples), dtype=np.float64)
        # Convert the traces to float64 one slice of the sample axis at a time so the
        # extra memory stays bounded for long traces.
        for lo in range(0, n_samples, chunk_size):
            hi = lo + chunk_size
            block = traces[:, lo:hi].astype(np.float64)
            block_centered = block - block.mean(axis=0)
            block_norm = np.sqrt(np.einsum("ij,ij->j", block_centered, block_centered))
            numer = model_centered.T @ block_centered
            denom = np.outer(model_norm, block_norm)
            correlations[:, lo:hi] = np.divide(
                numer, denom, out=np.full_like(numer, np.nan), where=denom > 0
            )

        # Rounding can push |r| marginally past 1; clip like np.corrcoef does.
        np.clip(correlations, -1.0, 1.0, out=correlations)
        # A constant model column carries no information: report 0 rather than NaN.
        correlations[model_norm == 0, :] = 0.0
        return correlations

    def perform_analysis(self, *analyzer_params: AnalysisParams, persist: bool = False) -> np.ndarray:
        """
        Run ``single_perform_analysis`` for every parameter set and stack the results.

        Args:
            *analyzer_params: one ``AnalysisParams`` per model selection; result rows
                are stacked in the order given.
            persist: if True, overwrite ``self.output_path`` with the stacked matrix at
                '/0/0/correlation' plus metadata.

        Returns:
            Array of shape (total model columns, n_samples).

        Raises:
            ValueError: if no parameter set is given, or from ``single_perform_analysis``.
        """
        if not analyzer_params:
            raise ValueError("至少需要一个 AnalysisParams 参数")

        correlation_matrix = np.vstack(
            [self.single_perform_analysis(**vars(param)) for param in analyzer_params]
        )

        if persist:
            store = zarr.DirectoryStore(self.output_path)
            root = zarr.group(store=store, overwrite=True)
            root.create_dataset(
                '/0/0/correlation',
                data=correlation_matrix,
                chunks=(16, 1000)
            )

            root.attrs.update({
                "analysis_metadata": {
                    "sample_range": self.sample_range,
                    "trace_count": self.sel_num_traces,
                    "model_type": "hamming_weight"
                }
            })

        return correlation_matrix

    def single_perform_analysis(self, data_type="plaintext", data_width=1, start: int = 0, count: int = None):
        """
        Correlate the Hamming weight of one slice of per-trace data with the traces.

        Args:
            data_type: attribute the model data is read from: "plaintext",
                "ciphertext", "key" or "extended".
            data_width: model word width in bytes, passed to ``hamming_weight`` as
                ``data_width * 8`` bits.
            start: first byte column.
            count: number of byte columns; ``None`` means through the last column.

        Returns:
            Correlation matrix of shape (n_model_columns, n_samples).

        Raises:
            ValueError: for an unknown ``data_type`` (and from ``hamming_weight`` or
                ``calculate_correlation``).
        """
        if data_type not in self._DATA_TYPES:
            raise ValueError(f"data_type error: [{data_type}].")

        n_traces = self.sel_num_traces
        # Use the same trace count for traces and model data so their rows line up.
        processed_traces = self.t[:n_traces, self.sample_range[0]:self.sample_range[1]]
        stop = None if count is None else start + count
        data = np.asarray(getattr(self, data_type)[:n_traces, start:stop])
        hw_matrix = self.hamming_weight(data, data_width * 8)

        return self.calculate_correlation(traces=processed_traces, data_bytes=hw_matrix)
@@ -0,0 +1,250 @@
1
+ # Copyright 2024 CrackNuts. All rights reserved.
2
+
3
+ import numba as nb
4
+ import numpy as np
5
+ import zarr
6
+ from dask import delayed # 添加delayed导入
7
+ from dask.diagnostics import ProgressBar
8
+
9
+ from cracknuts_squirrel.preprocessing_basic import PPBasic
10
+
11
# AES-128 S-box used to compute model values, one row of 16 entries per high
# nibble: entry 16*r + c is the substitution of input byte 0xrc.
AES_SBOX = np.array([
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
])

# Inverse S-box. AES_SBOX is a permutation of 0-255, so the indices that sort it
# form the inverse mapping (AES_invSBOX[AES_SBOX[x]] == x). The cast keeps the same
# default integer dtype that a literal np.array of Python ints gets.
AES_invSBOX = np.argsort(AES_SBOX).astype(AES_SBOX.dtype)
61
# Hamming weights of the values 0-255 used for model values.
# Entry v is the number of set bits in byte value v, stored as float32 so the
# looked-up model can feed straight into floating-point correlation math.
WEIGHTS = np.array([bin(byte_value).count("1") for byte_value in range(256)], np.float32)
79
+
80
# Byte permutation applied by inv_shift_rows: output byte i is input byte
# _INV_SHIFT_ROWS_IDX[i]. Viewing the 16 bytes as a row-major 4x4 matrix, row r is
# rotated left by r positions (np.roll(row, -r)).
_INV_SHIFT_ROWS_IDX = np.array([0, 1, 2, 3,
                                5, 6, 7, 4,
                                10, 11, 8, 9,
                                15, 12, 13, 14])


def inv_shift_rows(state):
    """
    Row-shift byte permutation used by the decryption-direction leakage model.

    :param state: array whose last axis holds 16-byte states, e.g. shape
        (n_traces, 16) or a single state of shape (16,).
    :return: a new array of the same shape and dtype with the bytes permuted;
        ``state`` itself is not modified, so repeated calls on one array are safe.
    :raises ValueError: if the last axis does not have length 16.

    The 16 bytes are viewed as a row-major 4x4 matrix and row r is rotated *left*
    by r positions.
    NOTE(review): FIPS-197 InvShiftRows rotates row r to the right on a
    column-major state (byte index r + 4c); confirm this permutation matches the
    target's byte layout before relying on the 'dec' model.
    """
    state = np.asarray(state)
    if state.ndim == 0 or state.shape[-1] != 16:
        raise ValueError(f"状态最后一维必须为 16 字节,当前形状为 {state.shape}")
    # Fancy indexing returns a copy, leaving the caller's array untouched.
    return state[..., _INV_SHIFT_ROWS_IDX]
99
+
100
class CPAAnalysis(PPBasic):
    """
    AES correlation power analysis (CPA), built on the PPBasic preprocessing base class.

    For each key-byte guess 0-255 a Hamming-weight leakage model is built from the
    first 16 input bytes of every trace and correlated with every selected sample.
    """

    def __init__(self, input_path=None, output_path=None, byte_pos=None, sample_range=(0, None), dr='enc', **kwargs):
        """
        :param input_path: dataset path, forwarded to PPBasic.
        :param output_path: zarr directory the results are written to, forwarded to PPBasic.
        :param byte_pos: key byte index (0-15) or a list/tuple of indices to report and
            store; ``None`` (the default) selects all 16 bytes.
        :param sample_range: (start, stop) slice bounds on the sample axis.
        :param dr: 'enc' for the S-box output model, 'dec' for the inverse S-box model.
        """
        super().__init__(input_path=input_path, output_path=output_path, **kwargs)
        if byte_pos is None:
            byte_pos = list(range(16))
        # Always a fresh list per instance (no shared mutable default).
        self.byte_pos = list(byte_pos) if isinstance(byte_pos, (list, tuple)) else [byte_pos]
        self.sample_range = sample_range
        self.dr = dr

    def perform_cpa(self):
        """
        Run CPA over all 256 guesses of each analysed key byte.

        Prints the best guess and the top-5 candidates for every byte in
        ``byte_pos``, then overwrites ``self.output_path`` with the correlations at
        '/0/0/correlation' plus metadata.

        :return: float32 array of shape (256, len(byte_pos), n_samples), the data
            written to the store.
        :raises ValueError: if ``self.dr`` is not 'enc'/'dec' or the sample range
            selects no samples.
        """
        if self.dr not in ('enc', 'dec'):
            raise ValueError(f"不支持的方向 dr={self.dr!r},仅支持 'enc' 或 'dec'")

        byte_pos = list(self.byte_pos)
        traces_centered, data_bytes, shifted_bytes = self._load_inputs(byte_pos)
        correlation = self._correlate_all_guesses(traces_centered, data_bytes, shifted_bytes)
        self._report_candidates(correlation, byte_pos)
        # Only touch (and overwrite) the output once the computation succeeded.
        self._store_results(correlation)
        return correlation

    def _load_inputs(self, byte_pos):
        """Materialize mean-centered float64 traces and the model input bytes for ``byte_pos``."""
        n_traces = self.sel_num_traces
        # np.array always copies, so centering in place never modifies self.t.
        traces = np.array(
            self.t[:n_traces, self.sample_range[0]:self.sample_range[1]], dtype=np.float64
        )
        if traces.shape[1] == 0:
            raise ValueError(f"采样范围 {self.sample_range} 未选中任何采样点")
        traces -= traces.mean(axis=0)

        # NOTE(review): the 'dec' model also reads self.plaintext — presumably the
        # decryption input is stored there; confirm against the acquisition setup.
        data = np.asarray(self.plaintext[:n_traces, :16])
        shifted = None
        if self.dr == 'dec':
            # Computed once on a copy; the model needs the same shifted bytes for every guess.
            shifted = np.asarray(inv_shift_rows(data.copy()))[:, byte_pos]
        return traces, data[:, byte_pos], shifted

    def _leakage_model(self, data_bytes, shifted_bytes, key_guess):
        """Hamming-weight leakage of ``data_bytes`` under ``key_guess`` for direction ``self.dr``."""
        # uint8 key keeps the XOR result dtype identical to XOR with a uint8 key array.
        xored = np.bitwise_xor(data_bytes, np.uint8(key_guess))
        if self.dr == 'enc':
            # HW(SBox(data ^ k))
            return WEIGHTS[AES_SBOX[xored]]
        # HW(InvSBox(data ^ k) ^ inv_shift_rows(data))
        return WEIGHTS[np.bitwise_xor(AES_invSBOX[xored], shifted_bytes)]

    def _correlate_all_guesses(self, traces_centered, data_bytes, shifted_bytes):
        """Pearson correlation for every key guess; returns float32 (256, n_bytes, n_samples)."""
        # Trace statistics do not depend on the key guess: compute them once.
        trace_norm = np.sqrt(np.einsum('ij,ij->j', traces_centered, traces_centered))
        correlation = np.empty(
            (256, data_bytes.shape[1], traces_centered.shape[1]), dtype=np.float32
        )
        for key_guess in range(256):
            model = self._leakage_model(data_bytes, shifted_bytes, key_guess).astype(np.float64)
            model -= model.mean(axis=0)
            model_norm = np.sqrt(np.einsum('ij,ij->j', model, model))
            # Centered dot products avoid the cancellation of E[xy] - E[x]E[y].
            numer = model.T @ traces_centered
            denom = np.outer(model_norm, trace_norm)
            # Zero variance in the model or in a sample column yields 0, not NaN.
            correlation[key_guess] = np.divide(
                numer, denom, out=np.zeros_like(numer), where=denom > 0
            )
        return correlation

    @staticmethod
    def _report_candidates(correlation, byte_pos):
        """Print the best and top-5 key guesses per byte, ranked by peak |correlation|."""
        # Signed correlation at the sample where |correlation| peaks, per (guess, byte).
        peak_idx = np.argmax(np.abs(correlation), axis=2)
        peaks = np.take_along_axis(correlation, peak_idx[..., np.newaxis], axis=2)[..., 0]

        for col, j in enumerate(byte_pos):
            column = peaks[:, col]
            best_key = np.abs(column).argmax()
            print(f'第{j+1}字节密钥: {hex(best_key)} 相关系数: {column[best_key]:.4f}')

            top5_indices = np.argsort(np.abs(column))[::-1][:5]
            for rank, idx in enumerate(top5_indices, 1):
                print(f'第{rank}候选值:{hex(idx)},相关系数:{column[idx]:.4f}')
            print('\n')

    def _store_results(self, correlation):
        """Overwrite ``self.output_path`` with ``correlation`` and the CPA metadata."""
        store = zarr.DirectoryStore(self.output_path)
        root = zarr.group(store=store, overwrite=True)
        root.create_dataset(
            '/0/0/correlation',
            data=correlation
        )
        root.attrs.update({
            "cpa_metadata": {
                "analyzed_byte": self.byte_pos,
                "sample_range": self.sample_range,
                "trace_count": self.sel_num_traces
            }
        })
242
+
243
if __name__ == "__main__":
    # Example usage. Raw strings, so each backslash is a single path separator.
    cpa = CPAAnalysis(input_path=r'E:\codes\Acquisition\dataset\20250909100916.zarr', dr='enc')
    # cpa = CPAAnalysis(input_path=r'E:\codes\template\dataset\nut476_aes_random_data.zarr', dr='enc')

    cpa.auto_out_filename()
    # cpa.set_range(sample_range=(4000, 5200))  # restrict the analysed sample points
    cpa.perform_cpa()