sm-crypto-v2 1.4.0 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "sm-crypto-v2",
3
- "version": "1.4.0",
3
+ "version": "1.5.0",
4
4
  "description": "sm-crypto-v2",
5
5
  "main": "dist/index.js",
6
6
  "module": "dist/index.mjs",
package/pnpm-lock.yaml CHANGED
@@ -1,4 +1,8 @@
1
- lockfileVersion: '6.0'
1
+ lockfileVersion: '6.1'
2
+
3
+ settings:
4
+ autoInstallPeers: true
5
+ excludeLinksFromLockfile: false
2
6
 
3
7
  dependencies:
4
8
  '@noble/curves':
package/src/index.ts CHANGED
@@ -1,3 +1,3 @@
1
1
  export * as sm2 from './sm2/index'
2
- export * as sm3 from './sm3/index'
3
- export * as sm4 from './sm4/index'
2
+ export { sm3 } from './sm3/index'
3
+ export * as sm4 from './sm4/index'
package/src/sm2/ec.ts CHANGED
@@ -1,89 +1,10 @@
1
1
  import { weierstrass } from '@noble/curves/abstract/weierstrass';
2
2
  import { Field } from '@noble/curves/abstract/modular'; // finite field for mod arithmetics
3
- import { hmac, sm3 } from './sm3'
4
- import { utf8ToArray } from '@/sm3';
5
- import { concatArray } from './utils';
6
3
  import { ONE } from './bn';
7
-
8
- /**
9
- * 安全随机数发生器
10
- * 如果有原生同步接口,直接使用。否则维护一个随机数池,使用异步接口实现。
11
- * Web: webcrypto 原生同步接口
12
- * Node: Node v18 之前需要引入 crypto 模块,这里使用异步 import
13
- * 小程序:异步接口
14
- */
15
- declare module wx {
16
- function getRandomValues(options: {
17
- length: number;
18
- success: (res: { randomValues: ArrayBuffer }) => void;
19
- }): void;
20
- }
21
-
22
- const DEFAULT_PRNG_POOL_SIZE = 16384
23
- let prngPool = new Uint8Array(0)
24
- let _syncCrypto: typeof import('crypto')['webcrypto']
25
- export async function initRNGPool() {
26
- if ('crypto' in globalThis) {
27
- _syncCrypto = globalThis.crypto
28
- return // no need to use pooling
29
- }
30
- if (prngPool.length > DEFAULT_PRNG_POOL_SIZE / 2) return // there is sufficient number
31
- // we always populate full pool size
32
- // since numbers may be consumed during micro tasks.
33
- if ('wx' in globalThis) {
34
- prngPool = await new Promise(r => {
35
- wx.getRandomValues({
36
- length: DEFAULT_PRNG_POOL_SIZE,
37
- success(res) {
38
- r(new Uint8Array(res.randomValues));
39
- }
40
- });
41
- });
42
- } else {
43
- // check if node, use webcrypto if available
44
- try {
45
- const crypto = await import(/* webpackIgnore: true */ 'crypto');
46
- _syncCrypto = crypto.webcrypto
47
- const array = new Uint8Array(DEFAULT_PRNG_POOL_SIZE);
48
- _syncCrypto.getRandomValues(array);
49
- prngPool = array;
50
- } catch (error) {
51
- throw new Error('no available csprng, abort.');
52
- }
53
- }
54
- }
55
-
56
- initRNGPool()
57
-
58
- function consumePool(length: number): Uint8Array {
59
- if (prngPool.length > length) {
60
- const prng = prngPool.slice(0, length)
61
- prngPool = prngPool.slice(length)
62
- initRNGPool()
63
- return prng
64
- } else {
65
- throw new Error('random number pool is not ready or insufficient, prevent getting too long random values or too often.')
66
- }
67
- }
68
-
69
- export function randomBytes(length = 0): Uint8Array {
70
- const array = new Uint8Array(length);
71
- if (_syncCrypto) {
72
- return _syncCrypto.getRandomValues(array);
73
- } else {
74
- // no sync crypto available, use async pool
75
- const result = consumePool(length)
76
- return result
77
- }
78
- }
79
-
80
- export function createHash() {
81
- const hashC = (msg: Uint8Array | string): Uint8Array => sm3(typeof msg === 'string' ? utf8ToArray(msg) : msg)
82
- hashC.outputLen = 256;
83
- hashC.blockLen = 512;
84
- hashC.create = () => sm3(Uint8Array.from([]));
85
- return hashC;
86
- }
4
+ import { randomBytes } from './rng';
5
+ import { sm3 } from './sm3';
6
+ import { hmac } from './hmac';
7
+ import { concatBytes } from '@noble/curves/abstract/utils';
87
8
 
88
9
  export const sm2Fp = Field(BigInt('115792089210356248756420345214020892766250353991924191454421193933289684991999'))
89
10
  export const sm2Curve = weierstrass({
@@ -95,7 +16,9 @@ export const sm2Curve = weierstrass({
95
16
  n: BigInt('115792089210356248756420345214020892766061623724957744567843809356293439045923'),
96
17
  Gx: BigInt('22963146547237050559479531362550074578802567295341616970375194840604139615431'),
97
18
  Gy: BigInt('85132369209828568825618990617112496413088388631904505083283536607588877201568'),
98
- hash: createHash(),
99
- hmac: (key: Uint8Array, ...msgs: Uint8Array[]) => hmac(concatArray(...msgs), key),
19
+ hash: sm3,
20
+ hmac: (key: Uint8Array, ...msgs: Uint8Array[]) => hmac(sm3, key, concatBytes(...msgs)),
100
21
  randomBytes,
101
22
  });
23
+ // 有限域运算
24
+ export const field = Field(BigInt(sm2Curve.CURVE.n))
@@ -0,0 +1,76 @@
1
+ import { Hash, CHash, Input, toBytes } from '../sm3/utils';
2
+ // HMAC (RFC 2104)
3
+ export class HMAC<T extends Hash<T>> extends Hash<HMAC<T>> {
4
+ oHash: T;
5
+ iHash: T;
6
+ blockLen: number;
7
+ outputLen: number;
8
+ private finished = false;
9
+ private destroyed = false;
10
+
11
+ constructor(hash: CHash, _key: Input) {
12
+ super();
13
+ const key = toBytes(_key);
14
+ this.iHash = hash.create() as T;
15
+ if (typeof this.iHash.update !== 'function')
16
+ throw new Error('Expected instance of class which extends utils.Hash');
17
+ this.blockLen = this.iHash.blockLen;
18
+ this.outputLen = this.iHash.outputLen;
19
+ const blockLen = this.blockLen;
20
+ const pad = new Uint8Array(blockLen);
21
+ // blockLen can be bigger than outputLen
22
+ pad.set(key.length > blockLen ? hash.create().update(key).digest() : key);
23
+ for (let i = 0; i < pad.length; i++) pad[i] ^= 0x36;
24
+ this.iHash.update(pad);
25
+ // By doing update (processing of first block) of outer hash here we can re-use it between multiple calls via clone
26
+ this.oHash = hash.create() as T;
27
+ // Undo internal XOR && apply outer XOR
28
+ for (let i = 0; i < pad.length; i++) pad[i] ^= 0x36 ^ 0x5c;
29
+ this.oHash.update(pad);
30
+ pad.fill(0);
31
+ }
32
+ update(buf: Input) {
33
+ this.iHash.update(buf);
34
+ return this;
35
+ }
36
+ digestInto(out: Uint8Array) {
37
+ this.finished = true;
38
+ this.iHash.digestInto(out);
39
+ this.oHash.update(out);
40
+ this.oHash.digestInto(out);
41
+ this.destroy();
42
+ }
43
+ digest() {
44
+ const out = new Uint8Array(this.oHash.outputLen);
45
+ this.digestInto(out);
46
+ return out;
47
+ }
48
+ _cloneInto(to?: HMAC<T>): HMAC<T> {
49
+ // Create new instance without calling constructor since key already in state and we don't know it.
50
+ to ||= Object.create(Object.getPrototypeOf(this), {});
51
+ const { oHash, iHash, finished, destroyed, blockLen, outputLen } = this;
52
+ to = to as this;
53
+ to.finished = finished;
54
+ to.destroyed = destroyed;
55
+ to.blockLen = blockLen;
56
+ to.outputLen = outputLen;
57
+ to.oHash = oHash._cloneInto(to.oHash);
58
+ to.iHash = iHash._cloneInto(to.iHash);
59
+ return to;
60
+ }
61
+ destroy() {
62
+ this.destroyed = true;
63
+ this.oHash.destroy();
64
+ this.iHash.destroy();
65
+ }
66
+ }
67
+
68
+ /**
69
+ * HMAC: RFC2104 message authentication code.
70
+ * @param hash - function that would be used e.g. sha256
71
+ * @param key - message key
72
+ * @param message - message data
73
+ */
74
+ export const hmac = (hash: CHash, key: Input, message: Input): Uint8Array =>
75
+ new HMAC<any>(hash, key).update(message).digest();
76
+ hmac.create = (hash: CHash, key: Input) => new HMAC<any>(hash, key);
package/src/sm2/index.ts CHANGED
@@ -1,19 +1,19 @@
1
1
  /* eslint-disable no-use-before-define */
2
2
  import { encodeDer, decodeDer } from './asn1'
3
- import { arrayToHex, arrayToUtf8, concatArray, generateKeyPairHex, hexToArray, leftPad, utf8ToHex } from './utils'
3
+ import { arrayToHex, arrayToUtf8, generateKeyPairHex, hexToArray, leftPad, utf8ToHex } from './utils'
4
4
  import { sm3 } from './sm3'
5
- import * as mod from '@noble/curves/abstract/modular';
6
5
  import * as utils from '@noble/curves/abstract/utils';
7
- import { sm2Curve } from './ec';
6
+ import { field, sm2Curve } from './ec';
8
7
  import { ONE, ZERO } from './bn';
8
+ import { bytesToHex } from '@/sm3/utils';
9
9
 
10
10
  export * from './utils'
11
- export { initRNGPool } from './ec'
11
+ export { initRNGPool } from './rng'
12
12
  export { calculateSharedKey } from './kx'
13
13
 
14
- // const { G, curve, n } = generateEcparam()
15
14
  const C1C2C3 = 0
16
-
15
+ // a empty array, just make tsc happy
16
+ export const EmptyArray = new Uint8Array()
17
17
  /**
18
18
  * 加密
19
19
  */
@@ -21,11 +21,9 @@ export function doEncrypt(msg: string | Uint8Array, publicKey: string, cipherMod
21
21
 
22
22
  const msgArr = typeof msg === 'string' ? hexToArray(utf8ToHex(msg)) : Uint8Array.from(msg)
23
23
  const publicKeyPoint = sm2Curve.ProjectivePoint.fromHex(publicKey)
24
- // const publicKeyPoint = getGlobalCurve().decodePointHex(publicKey) // 先将公钥转成点
25
24
 
26
25
  const keypair = generateKeyPairHex()
27
26
  const k = utils.hexToNumber(keypair.privateKey)
28
- // const k = new BigInteger(keypair.privateKey, 16) // 随机数 k
29
27
 
30
28
  // c1 = k * G
31
29
  let c1 = keypair.publicKey
@@ -37,31 +35,40 @@ export function doEncrypt(msg: string | Uint8Array, publicKey: string, cipherMod
37
35
  const y2 = hexToArray(leftPad(utils.numberToHexUnpadded(p.y), 64))
38
36
 
39
37
  // c3 = hash(x2 || msg || y2)
40
- const c3 = arrayToHex(Array.from(sm3(concatArray(x2, msgArr, y2))));
38
+ const c3 = bytesToHex(sm3(utils.concatBytes(x2, msgArr, y2)));
39
+
40
+ xorCipherStream(x2, y2, msgArr)
41
+ const c2 = bytesToHex(msgArr)
42
+
43
+ return cipherMode === C1C2C3 ? c1 + c2 + c3 : c1 + c3 + c2
44
+ }
41
45
 
46
+ function xorCipherStream(x2: Uint8Array, y2: Uint8Array, msg: Uint8Array) {
42
47
  let ct = 1
43
48
  let offset = 0
44
- let t: Uint8Array = new Uint8Array() // 256 位
45
- const z = concatArray(x2, y2)
49
+ let t = EmptyArray
50
+ const ctShift = new Uint8Array(4)
46
51
  const nextT = () => {
47
52
  // (1) Hai = hash(z || ct)
48
53
  // (2) ct++
49
- t = sm3(Uint8Array.from([...z, ct >> 24 & 0x00ff, ct >> 16 & 0x00ff, ct >> 8 & 0x00ff, ct & 0x00ff]))
54
+ ctShift[0] = ct >> 24 & 0x00ff
55
+ ctShift[1] = ct >> 16 & 0x00ff
56
+ ctShift[2] = ct >> 8 & 0x00ff
57
+ ctShift[3] = ct & 0x00ff
58
+ t = sm3(utils.concatBytes(x2, y2, ctShift))
50
59
  ct++
51
60
  offset = 0
52
61
  }
53
62
  nextT() // 先生成 Ha1
54
63
 
55
- for (let i = 0, len = msgArr.length; i < len; i++) {
64
+ for (let i = 0, len = msg.length; i < len; i++) {
56
65
  // t = Ha1 || Ha2 || Ha3 || Ha4
57
66
  if (offset === t.length) nextT()
58
67
 
59
68
  // c2 = msg ^ t
60
- msgArr[i] ^= t[offset++] & 0xff
69
+ msg[i] ^= t[offset++] & 0xff
61
70
  }
62
- const c2 = arrayToHex(Array.from(msgArr))
63
71
 
64
- return cipherMode === C1C2C3 ? c1 + c2 + c3 : c1 + c3 + c2
65
72
  }
66
73
 
67
74
  /**
@@ -76,7 +83,6 @@ export function doDecrypt(encryptData: string, privateKey: string, cipherMode?:
76
83
  export function doDecrypt(encryptData: string, privateKey: string, cipherMode = 1, {
77
84
  output = 'string',
78
85
  } = {}) {
79
- // const privateKeyInteger = new BigInteger(privateKey, 16)
80
86
  const privateKeyInteger = utils.hexToNumber(privateKey)
81
87
 
82
88
  let c3 = encryptData.substring(128, 128 + 64)
@@ -88,36 +94,15 @@ export function doDecrypt(encryptData: string, privateKey: string, cipherMode =
88
94
  }
89
95
 
90
96
  const msg = hexToArray(c2)
91
- // const c1 = getGlobalCurve().decodePointHex('04' + encryptData.substring(0, 128))!
92
97
  const c1 = sm2Curve.ProjectivePoint.fromHex('04' + encryptData.substring(0, 128))!
93
98
 
94
99
  const p = c1.multiply(privateKeyInteger)
95
- // const x2 = hexToArray(leftPad(p.getX().toBigInteger().toRadix(16), 64))
96
- // const y2 = hexToArray(leftPad(p.getY().toBigInteger().toRadix(16), 64))
97
100
  const x2 = hexToArray(leftPad(utils.numberToHexUnpadded(p.x), 64))
98
101
  const y2 = hexToArray(leftPad(utils.numberToHexUnpadded(p.y), 64))
99
- let ct = 1
100
- let offset = 0
101
- let t = new Uint8Array() // 256 位
102
- const z = concatArray(x2, y2)
103
- const nextT = () => {
104
- // (1) Hai = hash(z || ct)
105
- // (2) ct++
106
- t = sm3(Uint8Array.from([...z, ct >> 24 & 0x00ff, ct >> 16 & 0x00ff, ct >> 8 & 0x00ff, ct & 0x00ff]))
107
- ct++
108
- offset = 0
109
- }
110
- nextT() // 先生成 Ha1
111
-
112
- for (let i = 0, len = msg.length; i < len; i++) {
113
- // t = Ha1 || Ha2 || Ha3 || Ha4
114
- if (offset === t.length) nextT()
115
102
 
116
- // c2 = msg ^ t
117
- msg[i] ^= t[offset++] & 0xff
118
- }
103
+ xorCipherStream(x2, y2, msg)
119
104
  // c3 = hash(x2 || msg || y2)
120
- const checkC3 = arrayToHex(Array.from(sm3(concatArray(x2, msg, y2))))
105
+ const checkC3 = arrayToHex(Array.from(sm3(utils.concatBytes(x2, msg, y2))))
121
106
 
122
107
  if (checkC3 === c3.toLowerCase()) {
123
108
  return output === 'array' ? msg : arrayToUtf8(msg)
@@ -167,13 +152,11 @@ export function doSignature(msg: Uint8Array | string, privateKey: string, option
167
152
  k = point.k
168
153
 
169
154
  // r = (e + x1) mod n
170
- // r = e.add(point.x1).mod(n)
171
- r = mod.mod(e + point.x1, sm2Curve.CURVE.n)
155
+ r = field.add(e, point.x1)
172
156
  } while (r === ZERO || (r + k) === sm2Curve.CURVE.n)
173
157
 
174
158
  // s = ((1 + dA)^-1 * (k - r * dA)) mod n
175
- // s = dA.add(BigInteger.ONE).modInverse(n).multiply(k.subtract(r.multiply(dA))).mod(n)
176
- s = mod.mod(mod.invert(dA + ONE, sm2Curve.CURVE.n) * (k - r * dA), sm2Curve.CURVE.n)
159
+ s = field.mul(field.inv(field.addN(dA, ONE)), field.subN(k, field.mulN(r, dA)))
177
160
  } while (s === ZERO)
178
161
  if (der) return encodeDer(r, s) // asn.1 der 编码
179
162
  return leftPad(utils.numberToHexUnpadded(r), 64) + leftPad(utils.numberToHexUnpadded(s), 64)
@@ -203,30 +186,24 @@ export function doVerifySignature(msg: string | Uint8Array, signHex: string, pub
203
186
  r = decodeDerObj.r
204
187
  s = decodeDerObj.s
205
188
  } else {
206
- // r = new BigInteger(signHex.substring(0, 64), 16)
207
- // s = new BigInteger(signHex.substring(64), 16)
208
189
  r = utils.hexToNumber(signHex.substring(0, 64))
209
190
  s = utils.hexToNumber(signHex.substring(64))
210
191
  }
211
192
 
212
- // const PA = curve.decodePointHex(publicKey)!
213
193
  const PA = sm2Curve.ProjectivePoint.fromHex(publicKey)!
214
- // const e = new BigInteger(hashHex, 16)
215
194
  const e = utils.hexToNumber(hashHex)
216
195
 
217
196
  // t = (r + s) mod n
218
- // const t = r.add(s).mod(n)
219
- const t = mod.mod(r + s, sm2Curve.CURVE.n)
197
+ const t = field.add(r, s)
220
198
 
221
199
  if (t === ZERO) return false
222
200
 
223
201
  // x1y1 = s * G + t * PA
224
- // const x1y1 = G.multiply(s).add(PA.multiply(t))
225
202
  const x1y1 = sm2Curve.ProjectivePoint.BASE.multiply(s).add(PA.multiply(t))
226
203
 
227
204
  // R = (e + x1) mod n
228
205
  // const R = e.add(x1y1.getX().toBigInteger()).mod(n)
229
- const R = mod.mod(e + x1y1.x, sm2Curve.CURVE.n)
206
+ const R = field.add(e, x1y1.x)
230
207
 
231
208
  // return r.equals(R)
232
209
  return r === R
@@ -262,10 +239,10 @@ export function getHash(hashHex: string | Uint8Array, publicKey: string, userId
262
239
 
263
240
  const entl = userId.length * 4
264
241
 
265
- const z = sm3(concatArray(new Uint8Array([entl >> 8 & 0x00ff, entl & 0x00ff]), data))
242
+ const z = sm3(utils.concatBytes(new Uint8Array([entl >> 8 & 0x00ff, entl & 0x00ff]), data))
266
243
 
267
244
  // e = hash(z || msg)
268
- return arrayToHex(Array.from(sm3(concatArray(z, typeof hashHex === 'string' ? hexToArray(hashHex) : hashHex))))
245
+ return bytesToHex(sm3(utils.concatBytes(z, typeof hashHex === 'string' ? hexToArray(hashHex) : hashHex)))
269
246
  }
270
247
 
271
248
  /**
@@ -275,10 +252,6 @@ export function getPublicKeyFromPrivateKey(privateKey: string) {
275
252
  const pubKey = sm2Curve.getPublicKey(privateKey, false)
276
253
  const pubPad = leftPad(utils.bytesToHex(pubKey), 64)
277
254
  return pubPad
278
- // const PA = G.multiply(new BigInteger(privateKey, 16))
279
- // const x = leftPad(PA.getX().toBigInteger().toString(16), 64)
280
- // const y = leftPad(PA.getY().toBigInteger().toString(16), 64)
281
- // return '04' + x + y
282
255
  }
283
256
 
284
257
  /**
@@ -286,7 +259,6 @@ export function getPublicKeyFromPrivateKey(privateKey: string) {
286
259
  */
287
260
  export function getPoint() {
288
261
  const keypair = generateKeyPairHex()
289
- // const PA = curve.decodePointHex(keypair.publicKey)
290
262
  const PA = sm2Curve.ProjectivePoint.fromHex(keypair.publicKey)
291
263
  const k = utils.hexToNumber(keypair.privateKey)
292
264
 
package/src/sm2/kx.ts CHANGED
@@ -1,10 +1,9 @@
1
- import { sm2Curve } from './ec';
2
- import { KeyPair, concatArray, hexToArray, leftPad } from './utils';
1
+ import { field, sm2Curve } from './ec';
2
+ import { KeyPair, hexToArray, leftPad } from './utils';
3
3
  import * as utils from '@noble/curves/abstract/utils';
4
- import { Field } from '@noble/curves/abstract/modular';
5
4
  import { sm3 } from './sm3';
5
+ import { EmptyArray } from '.';
6
6
 
7
- export const field = Field(BigInt(sm2Curve.CURVE.n))
8
7
 
9
8
  // 用到的常数
10
9
  const wPow2 = utils.hexToNumber('80000000000000000000000000000000')
@@ -12,14 +11,19 @@ const wPow2Sub1 = utils.hexToNumber('7fffffffffffffffffffffffffffffff')
12
11
 
13
12
  // from sm2 sign part, extracted for code reusable.
14
13
  function hkdf(z: Uint8Array, keylen: number) {
15
- let t = new Uint8Array() // 256 位
16
14
  let msg = new Uint8Array(keylen)
17
15
  let ct = 1
18
16
  let offset = 0
17
+ let t = EmptyArray
18
+ const ctShift = new Uint8Array(4)
19
19
  const nextT = () => {
20
20
  // (1) Hai = hash(z || ct)
21
21
  // (2) ct++
22
- t = sm3(Uint8Array.from([...z, ct >> 24 & 0x00ff, ct >> 16 & 0x00ff, ct >> 8 & 0x00ff, ct & 0x00ff]))
22
+ ctShift[0] = ct >> 24 & 0x00ff
23
+ ctShift[1] = ct >> 16 & 0x00ff
24
+ ctShift[2] = ct >> 8 & 0x00ff
25
+ ctShift[3] = ct & 0x00ff
26
+ t = sm3(utils.concatBytes(z, ctShift))
23
27
  ct++
24
28
  offset = 0
25
29
  }
@@ -29,7 +33,7 @@ function hkdf(z: Uint8Array, keylen: number) {
29
33
  // t = Ha1 || Ha2 || Ha3 || Ha4
30
34
  if (offset === t.length) nextT()
31
35
 
32
- // c2 = msg ^ t
36
+ // 输出 stream
33
37
  msg[i] = t[offset++] & 0xff
34
38
  }
35
39
  return msg
@@ -47,18 +51,18 @@ export function calculateSharedKey(
47
51
  ) {
48
52
  const RA = sm2Curve.ProjectivePoint.fromHex(ephemeralKeypairA.publicKey)
49
53
  const RB = sm2Curve.ProjectivePoint.fromHex(ephemeralPublicKeyB)
50
- // const PA = sm2Curve.ProjectivePoint.fromHex(keypairA.publicKey) // 暂时用不到
54
+ // const PA = sm2Curve.ProjectivePoint.fromHex(keypairA.publicKey) // 用不到
51
55
  const PB = sm2Curve.ProjectivePoint.fromHex(publicKeyB)
52
56
  const ZA = hexToArray(idA)
53
57
  const ZB = hexToArray(idB)
54
58
  const rA = utils.hexToNumber(ephemeralKeypairA.privateKey)
55
59
  const dA = utils.hexToNumber(keypairA.privateKey)
56
- // 1.先算tA
60
+ // 1.先算 tA
57
61
  const x1 = RA.x
58
62
  // x1_ = 2^w + (x1 & (2^w - 1))
59
63
  const x1_ = field.add(wPow2, (x1 & wPow2Sub1))
60
64
  // tA = (dA + x1b * rA) mod n
61
- const tA = field.add(dA, field.mul(x1_, rA))
65
+ const tA = field.add(dA, field.mulN(x1_, rA))
62
66
 
63
67
  // 2.算 U
64
68
  // x2_ = 2^w + (x2 & (2^w - 1))
@@ -71,6 +75,6 @@ export function calculateSharedKey(
71
75
  // KA = KDF(xU || yU || ZA || ZB, kLen)
72
76
  const xU = hexToArray(leftPad(utils.numberToHexUnpadded(U.x), 64))
73
77
  const yU = hexToArray(leftPad(utils.numberToHexUnpadded(U.y), 64))
74
- const KA = hkdf(concatArray(xU, yU, ZA, ZB), sharedKeyLength)
78
+ const KA = hkdf(utils.concatBytes(xU, yU, ZA, ZB), sharedKeyLength)
75
79
  return KA
76
80
  }
package/src/sm2/rng.ts ADDED
@@ -0,0 +1,71 @@
1
+ /**
2
+ * 安全随机数发生器
3
+ * 如果有原生同步接口,直接使用。否则维护一个随机数池,使用异步接口实现。
4
+ * Web: webcrypto 原生同步接口
5
+ * Node: Node v18 之前需要引入 crypto 模块,这里使用异步 import
6
+ * 小程序:异步接口
7
+ */
8
+ declare module wx {
9
+ function getRandomValues(options: {
10
+ length: number;
11
+ success: (res: { randomValues: ArrayBuffer }) => void;
12
+ }): void;
13
+ }
14
+
15
+ const DEFAULT_PRNG_POOL_SIZE = 16384
16
+ let prngPool = new Uint8Array(0)
17
+ let _syncCrypto: typeof import('crypto')['webcrypto']
18
+ export async function initRNGPool() {
19
+ if ('crypto' in globalThis) {
20
+ _syncCrypto = globalThis.crypto
21
+ return // no need to use pooling
22
+ }
23
+ if (prngPool.length > DEFAULT_PRNG_POOL_SIZE / 2) return // there is sufficient number
24
+ // we always populate full pool size
25
+ // since numbers may be consumed during micro tasks.
26
+ if ('wx' in globalThis) {
27
+ prngPool = await new Promise(r => {
28
+ wx.getRandomValues({
29
+ length: DEFAULT_PRNG_POOL_SIZE,
30
+ success(res) {
31
+ r(new Uint8Array(res.randomValues));
32
+ }
33
+ });
34
+ });
35
+ } else {
36
+ // check if node, use webcrypto if available
37
+ try {
38
+ const crypto = await import(/* webpackIgnore: true */ 'crypto');
39
+ _syncCrypto = crypto.webcrypto
40
+ const array = new Uint8Array(DEFAULT_PRNG_POOL_SIZE);
41
+ _syncCrypto.getRandomValues(array);
42
+ prngPool = array;
43
+ } catch (error) {
44
+ throw new Error('no available csprng, abort.');
45
+ }
46
+ }
47
+ }
48
+
49
+ initRNGPool()
50
+
51
+ function consumePool(length: number): Uint8Array {
52
+ if (prngPool.length > length) {
53
+ const prng = prngPool.slice(0, length)
54
+ prngPool = prngPool.slice(length)
55
+ initRNGPool()
56
+ return prng
57
+ } else {
58
+ throw new Error('random number pool is not ready or insufficient, prevent getting too long random values or too often.')
59
+ }
60
+ }
61
+
62
+ export function randomBytes(length = 0): Uint8Array {
63
+ const array = new Uint8Array(length);
64
+ if (_syncCrypto) {
65
+ return _syncCrypto.getRandomValues(array);
66
+ } else {
67
+ // no sync crypto available, use async pool
68
+ const result = consumePool(length)
69
+ return result
70
+ }
71
+ }