@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334)
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.js +14 -0
  135. package/dist/{core/task → task}/display_information.d.ts +5 -3
  136. package/dist/task/display_information.js +46 -0
  137. package/dist/task/index.d.ts +7 -0
  138. package/dist/task/index.js +5 -0
  139. package/dist/task/label_type.d.ts +9 -0
  140. package/dist/task/label_type.js +28 -0
  141. package/dist/task/summary.js +13 -0
  142. package/dist/{core/task → task}/task.d.ts +7 -7
  143. package/dist/task/task.js +22 -0
  144. package/dist/task/task_handler.d.ts +5 -0
  145. package/dist/task/task_handler.js +20 -0
  146. package/dist/task/task_provider.d.ts +5 -0
  147. package/dist/task/task_provider.js +1 -0
  148. package/dist/{core/task → task}/training_information.d.ts +9 -10
  149. package/dist/task/training_information.js +88 -0
  150. package/dist/training/disco.d.ts +40 -0
  151. package/dist/training/disco.js +107 -0
  152. package/dist/training/index.d.ts +2 -0
  153. package/dist/training/index.js +1 -0
  154. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  155. package/dist/training/trainer/distributed_trainer.js +36 -0
  156. package/dist/training/trainer/local_trainer.d.ts +12 -0
  157. package/dist/training/trainer/local_trainer.js +19 -0
  158. package/dist/training/trainer/trainer.d.ts +33 -0
  159. package/dist/training/trainer/trainer.js +52 -0
  160. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  161. package/dist/training/trainer/trainer_builder.js +43 -0
  162. package/dist/types.d.ts +8 -0
  163. package/dist/types.js +1 -0
  164. package/dist/utils/event_emitter.d.ts +40 -0
  165. package/dist/utils/event_emitter.js +57 -0
  166. package/dist/validation/index.d.ts +1 -0
  167. package/dist/validation/index.js +1 -0
  168. package/dist/validation/validator.d.ts +28 -0
  169. package/dist/validation/validator.js +132 -0
  170. package/dist/weights/aggregation.d.ts +21 -0
  171. package/dist/weights/aggregation.js +44 -0
  172. package/dist/weights/index.d.ts +2 -0
  173. package/dist/weights/index.js +2 -0
  174. package/dist/weights/weights_container.d.ts +68 -0
  175. package/dist/weights/weights_container.js +96 -0
  176. package/package.json +24 -15
  177. package/README.md +0 -53
  178. package/dist/core/async_buffer.d.ts +0 -41
  179. package/dist/core/async_buffer.js +0 -97
  180. package/dist/core/async_informant.d.ts +0 -20
  181. package/dist/core/async_informant.js +0 -69
  182. package/dist/core/client/base.d.ts +0 -33
  183. package/dist/core/client/base.js +0 -35
  184. package/dist/core/client/decentralized/base.d.ts +0 -32
  185. package/dist/core/client/decentralized/base.js +0 -212
  186. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  187. package/dist/core/client/decentralized/clear_text.js +0 -96
  188. package/dist/core/client/decentralized/index.d.ts +0 -4
  189. package/dist/core/client/decentralized/index.js +0 -9
  190. package/dist/core/client/decentralized/messages.d.ts +0 -41
  191. package/dist/core/client/decentralized/messages.js +0 -54
  192. package/dist/core/client/decentralized/peer.d.ts +0 -26
  193. package/dist/core/client/decentralized/peer.js +0 -210
  194. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  195. package/dist/core/client/decentralized/peer_pool.js +0 -92
  196. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  197. package/dist/core/client/decentralized/sec_agg.js +0 -190
  198. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  199. package/dist/core/client/decentralized/secret_shares.js +0 -39
  200. package/dist/core/client/decentralized/types.d.ts +0 -2
  201. package/dist/core/client/decentralized/types.js +0 -7
  202. package/dist/core/client/event_connection.d.ts +0 -37
  203. package/dist/core/client/event_connection.js +0 -158
  204. package/dist/core/client/federated/client.d.ts +0 -37
  205. package/dist/core/client/federated/client.js +0 -273
  206. package/dist/core/client/federated/index.d.ts +0 -2
  207. package/dist/core/client/federated/index.js +0 -7
  208. package/dist/core/client/federated/messages.d.ts +0 -38
  209. package/dist/core/client/federated/messages.js +0 -25
  210. package/dist/core/client/index.d.ts +0 -5
  211. package/dist/core/client/index.js +0 -11
  212. package/dist/core/client/local.d.ts +0 -8
  213. package/dist/core/client/local.js +0 -36
  214. package/dist/core/client/messages.d.ts +0 -28
  215. package/dist/core/client/messages.js +0 -33
  216. package/dist/core/client/utils.d.ts +0 -2
  217. package/dist/core/client/utils.js +0 -19
  218. package/dist/core/dataset/data/data.d.ts +0 -11
  219. package/dist/core/dataset/data/data.js +0 -20
  220. package/dist/core/dataset/data/data_split.d.ts +0 -5
  221. package/dist/core/dataset/data/data_split.js +0 -2
  222. package/dist/core/dataset/data/image_data.d.ts +0 -8
  223. package/dist/core/dataset/data/image_data.js +0 -64
  224. package/dist/core/dataset/data/index.d.ts +0 -5
  225. package/dist/core/dataset/data/index.js +0 -11
  226. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  227. package/dist/core/dataset/data/preprocessing.js +0 -33
  228. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  229. package/dist/core/dataset/data/tabular_data.js +0 -40
  230. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  231. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  232. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  233. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  234. package/dist/core/dataset/data_loader/index.js +0 -9
  235. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  236. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  237. package/dist/core/dataset/dataset.d.ts +0 -2
  238. package/dist/core/dataset/dataset.js +0 -2
  239. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  240. package/dist/core/dataset/dataset_builder.js +0 -96
  241. package/dist/core/dataset/index.d.ts +0 -4
  242. package/dist/core/dataset/index.js +0 -14
  243. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  244. package/dist/core/default_tasks/geotags.d.ts +0 -2
  245. package/dist/core/default_tasks/geotags.js +0 -69
  246. package/dist/core/default_tasks/index.d.ts +0 -6
  247. package/dist/core/default_tasks/index.js +0 -15
  248. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  249. package/dist/core/default_tasks/lus_covid.js +0 -96
  250. package/dist/core/default_tasks/mnist.d.ts +0 -2
  251. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  252. package/dist/core/default_tasks/titanic.d.ts +0 -2
  253. package/dist/core/index.d.ts +0 -18
  254. package/dist/core/index.js +0 -39
  255. package/dist/core/informant/graph_informant.js +0 -23
  256. package/dist/core/informant/index.d.ts +0 -3
  257. package/dist/core/informant/index.js +0 -9
  258. package/dist/core/informant/training_informant/base.d.ts +0 -31
  259. package/dist/core/informant/training_informant/base.js +0 -83
  260. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  261. package/dist/core/informant/training_informant/decentralized.js +0 -22
  262. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  263. package/dist/core/informant/training_informant/federated.js +0 -32
  264. package/dist/core/informant/training_informant/index.d.ts +0 -4
  265. package/dist/core/informant/training_informant/index.js +0 -11
  266. package/dist/core/informant/training_informant/local.d.ts +0 -6
  267. package/dist/core/informant/training_informant/local.js +0 -20
  268. package/dist/core/logging/console_logger.js +0 -33
  269. package/dist/core/logging/index.d.ts +0 -3
  270. package/dist/core/logging/index.js +0 -9
  271. package/dist/core/logging/logger.js +0 -9
  272. package/dist/core/logging/trainer_logger.d.ts +0 -24
  273. package/dist/core/logging/trainer_logger.js +0 -59
  274. package/dist/core/memory/base.d.ts +0 -22
  275. package/dist/core/memory/base.js +0 -9
  276. package/dist/core/memory/empty.d.ts +0 -14
  277. package/dist/core/memory/empty.js +0 -75
  278. package/dist/core/memory/index.d.ts +0 -3
  279. package/dist/core/memory/index.js +0 -9
  280. package/dist/core/memory/model_type.d.ts +0 -4
  281. package/dist/core/memory/model_type.js +0 -9
  282. package/dist/core/serialization/index.d.ts +0 -2
  283. package/dist/core/serialization/index.js +0 -6
  284. package/dist/core/serialization/model.d.ts +0 -5
  285. package/dist/core/serialization/model.js +0 -55
  286. package/dist/core/serialization/weights.js +0 -64
  287. package/dist/core/task/data_example.js +0 -24
  288. package/dist/core/task/digest.js +0 -18
  289. package/dist/core/task/display_information.js +0 -49
  290. package/dist/core/task/index.d.ts +0 -6
  291. package/dist/core/task/index.js +0 -15
  292. package/dist/core/task/model_compile_data.d.ts +0 -6
  293. package/dist/core/task/model_compile_data.js +0 -22
  294. package/dist/core/task/summary.js +0 -19
  295. package/dist/core/task/task.js +0 -35
  296. package/dist/core/task/task_handler.d.ts +0 -5
  297. package/dist/core/task/task_handler.js +0 -53
  298. package/dist/core/task/task_provider.d.ts +0 -6
  299. package/dist/core/task/task_provider.js +0 -13
  300. package/dist/core/task/training_information.js +0 -66
  301. package/dist/core/training/disco.d.ts +0 -23
  302. package/dist/core/training/disco.js +0 -130
  303. package/dist/core/training/index.d.ts +0 -2
  304. package/dist/core/training/index.js +0 -7
  305. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  306. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  307. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  308. package/dist/core/training/trainer/local_trainer.js +0 -34
  309. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  310. package/dist/core/training/trainer/round_tracker.js +0 -47
  311. package/dist/core/training/trainer/trainer.d.ts +0 -65
  312. package/dist/core/training/trainer/trainer.js +0 -160
  313. package/dist/core/training/trainer/trainer_builder.js +0 -95
  314. package/dist/core/training/training_schemes.d.ts +0 -5
  315. package/dist/core/training/training_schemes.js +0 -10
  316. package/dist/core/types.d.ts +0 -4
  317. package/dist/core/types.js +0 -2
  318. package/dist/core/validation/index.d.ts +0 -1
  319. package/dist/core/validation/index.js +0 -5
  320. package/dist/core/validation/validator.d.ts +0 -17
  321. package/dist/core/validation/validator.js +0 -104
  322. package/dist/core/weights/aggregation.d.ts +0 -7
  323. package/dist/core/weights/aggregation.js +0 -72
  324. package/dist/core/weights/index.d.ts +0 -2
  325. package/dist/core/weights/index.js +0 -7
  326. package/dist/core/weights/weights_container.d.ts +0 -19
  327. package/dist/core/weights/weights_container.js +0 -64
  328. package/dist/imports.d.ts +0 -2
  329. package/dist/imports.js +0 -7
  330. package/dist/memory/memory.d.ts +0 -26
  331. package/dist/memory/memory.js +0 -160
  332. package/dist/{core/task → task}/data_example.d.ts +1 -1
  333. package/dist/{core/task → task}/digest.d.ts +0 -0
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -1,7 +1,4 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.addDifferentialPrivacy = void 0;
4
- var _1 = require(".");
1
+ import * as tf from '@tensorflow/tfjs';
5
2
  /**
6
3
  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
7
4
  * The previous round's weights are the last weights pulled from server/peers.
@@ -11,20 +8,19 @@ var _1 = require(".");
11
8
  * @param task the task
12
9
  * @returns the noised weights for the current round
13
10
  */
14
- function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
15
- var _a, _b;
16
- var noiseScale = (_a = task.trainingInformation) === null || _a === void 0 ? void 0 : _a.noiseScale;
17
- var clippingRadius = (_b = task.trainingInformation) === null || _b === void 0 ? void 0 : _b.clippingRadius;
18
- var weightsDiff = updatedWeights.sub(staleWeights);
19
- var newWeightsDiff;
11
+ export function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
12
+ const noiseScale = task.trainingInformation?.noiseScale;
13
+ const clippingRadius = task.trainingInformation?.clippingRadius;
14
+ const weightsDiff = updatedWeights.sub(staleWeights);
15
+ let newWeightsDiff;
20
16
  if (clippingRadius !== undefined) {
21
17
  // Frobenius norm
22
- var norm_1 = weightsDiff.frobeniusNorm();
23
- newWeightsDiff = weightsDiff.map(function (w) {
24
- var clipped = w.div(Math.max(1, norm_1 / clippingRadius));
18
+ const norm = weightsDiff.frobeniusNorm();
19
+ newWeightsDiff = weightsDiff.map((w) => {
20
+ const clipped = w.div(Math.max(1, norm / clippingRadius));
25
21
  if (noiseScale !== undefined) {
26
22
  // Add clipping and noise
27
- var noise = _1.tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
23
+ const noise = tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
28
24
  return clipped.add(noise);
29
25
  }
30
26
  else {
@@ -36,7 +32,7 @@ function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
36
32
  else {
37
33
  if (noiseScale !== undefined) {
38
34
  // Add noise without any clipping
39
- newWeightsDiff = weightsDiff.map(function (w) { return _1.tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)); });
35
+ newWeightsDiff = weightsDiff.map((w) => tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)));
40
36
  }
41
37
  else {
42
38
  return updatedWeights;
@@ -44,4 +40,3 @@ function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
44
40
  }
45
41
  return staleWeights.add(newWeightsDiff);
46
42
  }
47
- exports.addDifferentialPrivacy = addDifferentialPrivacy;
@@ -0,0 +1,2 @@
1
+ export * as model from './model.js';
2
+ export * as weights from './weights.js';
@@ -0,0 +1,2 @@
1
+ export * as model from './model.js';
2
+ export * as weights from './weights.js';
@@ -0,0 +1,5 @@
1
+ import type { Model } from '../index.js';
/** Serialized model: msgpack bytes whose first element tags the model type. */
export type Encoded = Uint8Array;
/** Type guard: true when `raw` is a `Uint8Array`. */
export declare function isEncoded(raw: unknown): raw is Encoded;
/**
 * Serializes a supported model (TFJS or GPT) into msgpack bytes.
 * @throws Error when the model is neither a TFJS nor a GPT model
 */
export declare function encode(model: Model): Promise<Encoded>;
/**
 * Reconstructs a model from bytes produced by `encode`.
 * @throws Error when `encoded` is not a Uint8Array or does not hold a recognized model encoding
 */
export declare function decode(encoded: unknown): Promise<Model>;
@@ -0,0 +1,67 @@
1
+ import msgpack from 'msgpack-lite';
2
+ import { models, serialization } from '../index.js';
/** Numeric tags written as the first element of an encoded model, one per supported model class. */
const Type = { TFJS: 0, GPT: 1 };
/**
 * Type guard for an encoded model.
 * @param raw value to inspect
 * @returns true when `raw` is a Uint8Array (or an instance of a subclass of it)
 */
export function isEncoded(raw) {
    const isByteArray = raw instanceof Uint8Array;
    return isByteArray;
}
/**
 * Serializes a model into msgpack bytes tagged with its type.
 *
 * Layout: `[Type.TFJS, serialized]` for TFJS models, and
 * `[Type.GPT, encodedWeights, config]` for GPT models.
 * The TFJS check runs first, so a model matching both classes is encoded as TFJS.
 *
 * @param model the model to serialize
 * @returns the msgpack-encoded bytes
 * @throws Error when the model is neither a TFJS nor a GPT model
 */
export async function encode(model) {
    if (model instanceof models.TFJS) {
        return msgpack.encode([Type.TFJS, await model.serialize()]);
    }
    else if (model instanceof models.GPT) {
        const parts = model.serialize();
        const encodedWeights = await serialization.weights.encode(parts.weights);
        return msgpack.encode([Type.GPT, encodedWeights, parts.config]);
    }
    else {
        throw new Error('unknown model type');
    }
}
/**
 * Reconstructs a model from bytes produced by `encode`.
 *
 * @param encoded value to decode; must be a Uint8Array holding a msgpack array whose
 *   first element is a `Type` tag
 * @returns the deserialized model
 * @throws Error when `encoded` is not a Uint8Array, the decoded value is not an array
 *   of the expected length, the type tag is missing or unknown, or the GPT payload
 *   (weights / config) is malformed
 */
export async function decode(encoded) {
    if (!isEncoded(encoded)) {
        throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array");
    }
    const raw = msgpack.decode(encoded);
    if (!Array.isArray(raw) || raw.length < 2) {
        throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values");
    }
    const type = raw[0];
    if (typeof type !== 'number') {
        throw new Error('invalid encoding, first encoding field should be the model type');
    }
    const rawModel = raw[1];
    switch (type) {
        case Type.TFJS:
            if (raw.length !== 2) {
                throw new Error('invalid encoding, TFJS model encoding should be an array of length 2');
            }
            // TODO totally unsafe casting: the serialized TFJS payload is not validated here
            return await models.TFJS.deserialize(rawModel);
        case Type.GPT: {
            if (raw.length !== 2 && raw.length !== 3) {
                throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3');
            }
            // `encode` always writes a third slot, even when the config is undefined.
            // msgpack has no `undefined`, so such a config presumably decodes as `null`.
            // Normalize it back to `undefined` so both layouts yield the same "no config" value.
            const config = raw.length === 3 ? (raw[2] ?? undefined) : undefined;
            if (config !== undefined && typeof config !== 'object') {
                throw new Error('invalid encoding, gpt-tfjs model config should be an object');
            }
            if (!Array.isArray(rawModel)) {
                throw new Error('invalid encoding, gpt-tfjs model weights should be an array');
            }
            if (rawModel.some((r) => typeof r !== 'number')) {
                throw new Error("invalid encoding, gpt-tfjs weights should be numbers");
            }
            const weights = serialization.weights.decode(rawModel);
            return models.GPT.deserialize({ weights, config });
        }
        default:
            throw new Error('invalid encoding, model type unrecognized');
    }
}
@@ -1,5 +1,5 @@
1
- import { WeightsContainer } from '..';
2
- export declare type Encoded = number[];
1
+ import { WeightsContainer } from '../index.js';
2
+ export type Encoded = number[];
3
3
  export declare function isEncoded(raw: unknown): raw is Encoded;
4
4
  export declare function encode(weights: WeightsContainer): Promise<Encoded>;
5
5
  export declare function decode(encoded: Encoded): WeightsContainer;
@@ -0,0 +1,37 @@
1
+ import * as msgpack from 'msgpack-lite';
2
+ import * as tf from '@tensorflow/tfjs';
3
+ import { WeightsContainer } from '../index.js';
/**
 * Whether `xs` is an array whose every element is a number.
 * @param xs value to inspect
 */
function isNumberArray(xs) {
    return Array.isArray(xs) && xs.every((x) => typeof x === 'number');
}
/**
 * Type guard for one serialized tensor entry: an object with a numeric `shape`
 * array and a numeric `data` array.
 * Whether `data.length` matches the product of `shape` is NOT checked here.
 * @param raw candidate entry, e.g. one element of a decoded msgpack array
 * @returns true when both fields are arrays of numbers
 */
function isSerialized(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    const { shape, data } = raw;
    return isNumberArray(shape) && isNumberArray(data);
}
/**
 * Type guard for encoded weights: a plain array containing only numbers
 * (the msgpack bytes produced by `encode`).
 * @param raw value to inspect
 */
export function isEncoded(raw) {
    if (!Array.isArray(raw)) {
        return false;
    }
    // `some` visits the same elements as `every` (holes skipped), so results match
    return !raw.some((e) => typeof e !== 'number');
}
/**
 * Encodes every tensor of a weights container as `{ shape, data }`, msgpack-encodes
 * the resulting list, and returns the bytes as a plain number array.
 * @param weights container whose `weights` tensors are serialized in order
 * @returns the msgpack bytes as numbers
 */
export async function encode(weights) {
    const tensors = weights.weights;
    const serialized = await Promise.all(tensors.map(async (tensor) => ({
        shape: tensor.shape,
        data: Array.from(await tensor.data()),
    })));
    const bytes = msgpack.encode(serialized);
    return Array.from(bytes.values());
}
/**
 * Rebuilds a weights container from bytes produced by `encode`.
 * @param encoded msgpack bytes
 * @returns a WeightsContainer holding one tensor per serialized entry, in order
 * @throws Error when the decoded value is not an array of `{ shape, data }` entries,
 *   or (from `tf.tensor`) when an entry's data does not fit its shape
 */
export function decode(encoded) {
    const raw = msgpack.decode(encoded);
    if (!(Array.isArray(raw) && raw.every(isSerialized))) {
        throw new Error('expected to decode an array of serialized weights');
    }
    // `tf.tensor` throws when data.length does not match the shape. Building inside
    // `tf.tidy` disposes the tensors already created for earlier entries in that case.
    // On success, the returned tensors are kept alive.
    const tensors = tf.tidy(() => raw.map((w) => tf.tensor(w.data, w.shape)));
    return new WeightsContainer(tensors);
}
@@ -0,0 +1,14 @@
/**
 * Type guard for a data example: a string `columnName` paired with a string or
 * numeric `columnData`.
 * @param raw value to inspect
 */
export function isDataExample(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    const { columnName, columnData } = raw;
    const dataKind = typeof columnData;
    return typeof columnName === 'string' && (dataKind === 'string' || dataKind === 'number');
}
@@ -0,0 +1,14 @@
/**
 * Type guard for a digest: an object whose `algorithm` and `value` are both strings.
 * @param raw value to inspect
 */
export function isDigest(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    const { algorithm, value } = raw;
    return [algorithm, value].every((field) => typeof field === 'string');
}
@@ -1,6 +1,6 @@
1
- import { Summary } from './summary';
2
- import { DataExample } from './data_example';
3
- export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
1
+ import { type Summary } from './summary.js';
2
+ import { type DataExample } from './data_example.js';
3
+ import { type LabelType } from './label_type.js';
4
4
  export interface DisplayInformation {
5
5
  taskTitle?: string;
6
6
  summary?: Summary;
@@ -12,4 +12,6 @@ export interface DisplayInformation {
12
12
  headers?: string[];
13
13
  dataExampleImage?: string;
14
14
  limitations?: string;
15
+ labelDisplay?: LabelType;
15
16
  }
17
+ export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
@@ -0,0 +1,46 @@
1
+ import { isSummary } from './summary.js';
2
+ import { isDataExample } from './data_example.js';
3
+ import { isLabelType } from './label_type.js';
/**
 * Type guard for a task's display information.
 *
 * Requires `taskTitle` to be a string. Every other field may be absent but, when
 * present, must have the right type:
 * - plain string fields must be strings;
 * - `labelDisplay` must satisfy `isLabelType`;
 * - `summary` must satisfy `isSummary`;
 * - `dataExample` must be an array of `isDataExample` entries;
 * - `headers` must be an array of strings.
 *
 * NOTE(review): the DisplayInformation declaration marks `taskTitle` optional, but this
 * guard rejects objects without it — confirm which is intended.
 * @param raw value to inspect
 */
export function isDisplayInformation(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    const { dataExample, dataExampleImage, dataExampleText, dataFormatInformation, headers, labelDisplay, limitations, model, summary, taskTitle, tradeoffs } = raw;
    if (typeof taskTitle !== 'string') {
        return false;
    }
    const optionalStrings = [dataExampleText, dataFormatInformation, tradeoffs, model, dataExampleImage, limitations];
    if (optionalStrings.some((field) => field !== undefined && typeof field !== 'string')) {
        return false;
    }
    if (labelDisplay !== undefined && !isLabelType(labelDisplay)) {
        return false;
    }
    if (summary !== undefined && !isSummary(summary)) {
        return false;
    }
    const dataExampleOk = dataExample === undefined ||
        (Array.isArray(dataExample) && dataExample.every(isDataExample));
    const headersOk = headers === undefined ||
        (Array.isArray(headers) && headers.every((h) => typeof h === 'string'));
    return dataExampleOk && headersOk;
}
@@ -0,0 +1,7 @@
1
+ export { isTask, type Task, isTaskID, type TaskID } from './task.js';
2
+ export { type TaskProvider } from './task_provider.js';
3
+ export { isDigest, type Digest } from './digest.js';
4
+ export { isDisplayInformation, type DisplayInformation } from './display_information.js';
5
+ export type { TrainingInformation } from './training_information.js';
6
+ export { pushTask, fetchTasks } from './task_handler.js';
7
+ export { LabelTypeEnum } from './label_type.js';
@@ -0,0 +1,5 @@
1
+ export { isTask, isTaskID } from './task.js';
2
+ export { isDigest } from './digest.js';
3
+ export { isDisplayInformation } from './display_information.js';
4
+ export { pushTask, fetchTasks } from './task_handler.js';
5
+ export { LabelTypeEnum } from './label_type.js';
@@ -0,0 +1,9 @@
/** How a task's labels are displayed. */
export interface LabelType {
    labelType: LabelTypeEnum;
    // optional base URL; presumably only meaningful for POLYGON_MAP — TODO confirm
    mapBaseUrl?: string;
}
/** Supported label display kinds (numeric enum: TEXT = 0, POLYGON_MAP = 1). */
export declare enum LabelTypeEnum {
    TEXT = 0,
    POLYGON_MAP = 1
}
/**
 * Type guard: `labelType` must be a LabelTypeEnum member and `mapBaseUrl`, when
 * present, must be a string.
 */
export declare function isLabelType(raw: unknown): raw is LabelType;
@@ -0,0 +1,28 @@
/** Label display kinds; a numeric enum with reverse mapping (value → name). */
export var LabelTypeEnum;
(function (LabelTypeEnum) {
    LabelTypeEnum[LabelTypeEnum["TEXT"] = 0] = "TEXT";
    LabelTypeEnum[LabelTypeEnum["POLYGON_MAP"] = 1] = "POLYGON_MAP";
})(LabelTypeEnum || (LabelTypeEnum = {}));
/**
 * Whether `raw` is exactly one of the LabelTypeEnum values (strict equality, so the
 * string "0" is rejected).
 * @param raw value to inspect
 */
function isLabelTypeEnum(raw) {
    return raw === LabelTypeEnum.TEXT || raw === LabelTypeEnum.POLYGON_MAP;
}
/**
 * Type guard for a label display descriptor.
 * @param raw value to inspect
 * @returns true when `labelType` is a LabelTypeEnum value and `mapBaseUrl` is absent or a string
 */
export function isLabelType(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    const { labelType, mapBaseUrl } = raw;
    if (!isLabelTypeEnum(labelType)) {
        return false;
    }
    return mapBaseUrl === undefined || typeof mapBaseUrl === 'string';
}
@@ -0,0 +1,13 @@
/**
 * Type guard for a task summary: both `preview` and `overview` must be strings.
 * @param raw value to inspect
 */
export function isSummary(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    const { preview, overview } = raw;
    return typeof preview === 'string' && typeof overview === 'string';
}
@@ -1,12 +1,12 @@
1
- import { DisplayInformation } from './display_information';
2
- import { TrainingInformation } from './training_information';
3
- import { Digest } from './digest';
4
- export declare type TaskID = string;
5
- export declare function isTaskID(obj: unknown): obj is TaskID;
6
- export declare function isTask(raw: unknown): raw is Task;
1
+ import { type DisplayInformation } from './display_information.js';
2
+ import { type TrainingInformation } from './training_information.js';
3
+ import { type Digest } from './digest.js';
4
/** Identifier of a task; any string. */
export type TaskID = string;

/** Complete description of a task. */
export interface Task {
  /** Unique identifier of the task. */
  id: TaskID;
  /** Optional digest of the task. */
  digest?: Digest;
  /** Information meant for display. */
  displayInformation: DisplayInformation;
  /** Parameters driving training. */
  trainingInformation: TrainingInformation;
}

/** Runtime guard: true when `obj` is a string. */
export declare function isTaskID(obj: unknown): obj is TaskID;

/**
 * Runtime guard for {@link Task}: requires a string `id`, valid
 * `displayInformation` and `trainingInformation`, and either no `digest`
 * or a valid one.
 */
export declare function isTask(raw: unknown): raw is Task;
@@ -0,0 +1,22 @@
1
+ import { isDisplayInformation } from './display_information.js';
2
+ import { isTrainingInformation } from './training_information.js';
3
+ import { isDigest } from './digest.js';
4
/** Task IDs are plain strings: returns whether `obj` is one. */
export function isTaskID(obj) {
  const kind = typeof obj;
  return kind === 'string';
}
7
/**
 * Runtime guard for a Task.
 * @param raw value of unknown shape (e.g. a task received from a server)
 * @returns true when `raw` is a non-null object with a valid `id`,
 *   `displayInformation` and `trainingInformation`, and either no `digest`
 *   or a valid one
 */
export function isTask(raw) {
  if (typeof raw !== 'object' || raw === null) {
    return false;
  }
  const { id, digest, displayInformation, trainingInformation } = raw;
  return isTaskID(id) &&
    (digest === undefined || isDigest(digest)) &&
    isDisplayInformation(displayInformation) &&
    isTrainingInformation(trainingInformation);
}
@@ -0,0 +1,5 @@
1
+ import { Map } from 'immutable';
2
+ import type { Model } from '../index.js';
3
+ import type { Task, TaskID } from './task.js';
4
/**
 * Publishes `task` to the server at `url` (its `tasks` endpoint), sending the
 * serialized `model` and its serialized weights along with it.
 */
export declare function pushTask(url: URL, task: Task, model: Model): Promise<void>;

/**
 * Retrieves the tasks served at `url`, keyed by task ID.
 * Rejects when the response is not an array of valid tasks.
 */
export declare function fetchTasks(url: URL): Promise<Map<TaskID, Task>>;
@@ -0,0 +1,20 @@
1
+ import axios from 'axios';
2
+ import { Map } from 'immutable';
3
+ import { serialization } from '../index.js';
4
+ import { isTask } from './task.js';
5
/** Path, relative to the server base URL, of the tasks endpoint. */
const TASK_ENDPOINT = 'tasks';

/**
 * Resolves the tasks endpoint against the server base URL, following WHATWG
 * URL resolution so every request targets the same place.
 * Note: `url` should end with '/' for 'tasks' to be appended; otherwise it
 * replaces the last path segment.
 */
function tasksURL(url) {
  return new URL(TASK_ENDPOINT, url).href;
}

/**
 * Uploads a task to the server, together with its serialized model and weights.
 * @param url server base URL
 * @param task task description to publish
 * @param model model to serialize and send alongside the task
 */
export async function pushTask(url, task, model) {
  // Resolve like fetchTasks does: plain string concatenation produced broken
  // URLs for bases without a trailing slash or with a query string.
  await axios.post(tasksURL(url), {
    task,
    model: await serialization.model.encode(model),
    weights: await serialization.weights.encode(model.weights),
  });
}

/**
 * Fetches the tasks served at `url`.
 * @param url server base URL
 * @returns the tasks, keyed by their `id`
 * @throws Error('invalid tasks response') when the response body is not an
 *   array of valid tasks
 */
export async function fetchTasks(url) {
  const response = await axios.get(tasksURL(url));
  const tasks = response.data;
  if (!(Array.isArray(tasks) && tasks.every(isTask))) {
    throw new Error('invalid tasks response');
  }
  return Map(tasks.map((t) => [t.id, t]));
}
@@ -0,0 +1,5 @@
1
+ import type { Model, Task } from '../index.js';
2
/** Supplies a task description and a model for that task. */
export interface TaskProvider {
  /** Returns the task description (synchronously). */
  getTask: () => Task;
  /** Resolves to a model for the task; whether it is built or loaded depends on the implementation. */
  getModel: () => Promise<Model>;
}
@@ -0,0 +1 @@
1
+ export {};
@@ -1,6 +1,6 @@
1
- import { Preprocessing } from '../dataset/data/preprocessing';
2
- import { ModelCompileData } from './model_compile_data';
3
- export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
1
+ import type { AggregatorChoice } from '../aggregator/get.js';
2
+ import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
3
+ import { PreTrainedTokenizer } from '@xenova/transformers';
4
4
  export interface TrainingInformation {
5
5
  modelID: string;
6
6
  epochs: number;
@@ -8,21 +8,20 @@ export interface TrainingInformation {
8
8
  validationSplit: number;
9
9
  batchSize: number;
10
10
  preprocessingFunctions?: Preprocessing[];
11
- modelCompileData: ModelCompileData;
12
- dataType: string;
11
+ dataType: 'image' | 'tabular' | 'text';
13
12
  inputColumns?: string[];
14
13
  outputColumns?: string[];
15
14
  IMAGE_H?: number;
16
15
  IMAGE_W?: number;
17
- modelURL?: string;
18
16
  LABEL_LIST?: string[];
19
- learningRate?: number;
20
- scheme: string;
17
+ scheme: 'decentralized' | 'federated' | 'local';
21
18
  noiseScale?: number;
22
19
  clippingRadius?: number;
23
20
  decentralizedSecure?: boolean;
24
- byzantineRobustAggregator?: boolean;
25
- tauPercentile?: number;
26
21
  maxShareValue?: number;
27
22
  minimumReadyPeers?: number;
23
+ aggregator?: AggregatorChoice;
24
+ tokenizer?: string | PreTrainedTokenizer;
25
+ maxSequenceLength?: number;
28
26
  }
27
/**
 * Runtime guard for {@link TrainingInformation}: checks field types, the
 * allowed `dataType` and `scheme` values, and that image tasks declare
 * numeric IMAGE_H and IMAGE_W.
 */
export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
@@ -0,0 +1,88 @@
1
+ import { PreTrainedTokenizer } from '@xenova/transformers';
2
/** Whether `raw` is an array whose elements are all strings (an empty array qualifies). */
function isStringArray(raw) {
  return Array.isArray(raw) && raw.every((e) => typeof e === 'string');
}
/**
 * Runtime guard for TrainingInformation, e.g. for task descriptions received
 * over the network.
 *
 * Checks that:
 * - `raw` is a non-null object;
 * - `modelID` and `dataType` are strings; `epochs`, `batchSize`,
 *   `roundDuration` and `validationSplit` are numbers;
 * - each optional field, when present, has the expected runtime type
 *   (`tokenizer` is a string or a PreTrainedTokenizer instance, `aggregator`
 *   a number, `preprocessingFunctions` an array);
 * - `dataType` is 'image', 'tabular' or 'text';
 * - image tasks declare numeric IMAGE_H and IMAGE_W;
 * - tabular tasks declare string-array inputColumns and outputColumns;
 * - `scheme` is 'decentralized', 'federated' or 'local'.
 *
 * @param raw value of unknown shape
 * @returns whether `raw` can be used as a TrainingInformation
 */
export function isTrainingInformation(raw) {
  if (typeof raw !== 'object' || raw === null) {
    return false;
  }
  const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, clippingRadius, dataType, decentralizedSecure, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, noiseScale, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, } = raw;
  // Required fields.
  if (typeof dataType !== 'string' ||
    typeof modelID !== 'string' ||
    typeof epochs !== 'number' ||
    typeof batchSize !== 'number' ||
    typeof roundDuration !== 'number' ||
    typeof validationSplit !== 'number') {
    return false;
  }
  // Optional fields: only type-checked when present.
  if ((tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
    (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
    (aggregator !== undefined && typeof aggregator !== 'number') ||
    (clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
    (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
    (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
    (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
    (noiseScale !== undefined && typeof noiseScale !== 'number') ||
    (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
    (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
    (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
    (inputColumns !== undefined && !isStringArray(inputColumns)) ||
    (outputColumns !== undefined && !isStringArray(outputColumns)) ||
    (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
    return false;
  }
  // Allowed data types and their interdependent fields. Values are compared
  // directly: `x in ['text', 'tabular']` would test array indices ('0', '1'),
  // never these strings, and so would never match.
  switch (dataType) {
    case 'image':
      if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
        return false;
      }
      break;
    case 'tabular':
      if (!isStringArray(inputColumns) || !isStringArray(outputColumns)) {
        return false;
      }
      break;
    case 'text':
      // NOTE(review): text tasks are not required to declare columns here;
      // some appear to be described only via tokenizer / maxSequenceLength.
      // Confirm against the task definitions before tightening this.
      break;
    default:
      return false;
  }
  switch (scheme) {
    case 'decentralized':
    case 'federated':
    case 'local':
      return true;
    default:
      return false;
  }
}
@@ -0,0 +1,40 @@
1
+ import type { data, Logger, Memory, Task, TrainingInformation } from '../index.js';
2
+ import { client as clients } from '../index.js';
3
+ import type { Aggregator } from '../aggregator/index.js';
4
+ import type { RoundLogs } from './trainer/trainer.js';
5
/**
 * Optional settings accepted by the {@link Disco} constructor.
 * NOTE(review): how omitted fields are defaulted is decided by the Disco
 * implementation, which is not visible here.
 */
export interface DiscoOptions {
  /** Pre-built client used to communicate with other nodes. */
  client?: clients.Client;
  /** Aggregator to use. */
  aggregator?: Aggregator;
  /** Base URL, presumably of the coordinating server; a string or a URL. */
  url?: string | URL;
  /** Training scheme, one of TrainingInformation's `scheme` values. */
  scheme?: TrainingInformation['scheme'];
  /** Logger to report through. */
  logger?: Logger;
  /** Model memory to use. */
  memory?: Memory;
}
13
/**
 * Client-side entry point for distributed training: a small but complete API
 * bundling model training, communication with other nodes, logging and
 * model memory.
 */
export declare class Disco {
  readonly task: Task;
  readonly logger: Logger;
  readonly memory: Memory;
  private readonly client;
  private readonly trainer;
  constructor(task: Task, options: DiscoOptions);
  /**
   * Trains this instance's task on the given data.
   * @param dataTuple the data split to train on
   * @returns an async generator yielding round logs together with the
   *   number of participants
   */
  fit(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & { participants: number }>;
  /** Stops the ongoing training without disconnecting the client. */
  pause(): Promise<void>;
  /** Stops the ongoing training completely. */
  close(): Promise<void>;
}