@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334)
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.js +14 -0
  135. package/dist/{core/task → task}/display_information.d.ts +5 -3
  136. package/dist/task/display_information.js +46 -0
  137. package/dist/task/index.d.ts +7 -0
  138. package/dist/task/index.js +5 -0
  139. package/dist/task/label_type.d.ts +9 -0
  140. package/dist/task/label_type.js +28 -0
  141. package/dist/task/summary.js +13 -0
  142. package/dist/{core/task → task}/task.d.ts +7 -7
  143. package/dist/task/task.js +22 -0
  144. package/dist/task/task_handler.d.ts +5 -0
  145. package/dist/task/task_handler.js +20 -0
  146. package/dist/task/task_provider.d.ts +5 -0
  147. package/dist/task/task_provider.js +1 -0
  148. package/dist/{core/task → task}/training_information.d.ts +9 -10
  149. package/dist/task/training_information.js +88 -0
  150. package/dist/training/disco.d.ts +40 -0
  151. package/dist/training/disco.js +107 -0
  152. package/dist/training/index.d.ts +2 -0
  153. package/dist/training/index.js +1 -0
  154. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  155. package/dist/training/trainer/distributed_trainer.js +36 -0
  156. package/dist/training/trainer/local_trainer.d.ts +12 -0
  157. package/dist/training/trainer/local_trainer.js +19 -0
  158. package/dist/training/trainer/trainer.d.ts +33 -0
  159. package/dist/training/trainer/trainer.js +52 -0
  160. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  161. package/dist/training/trainer/trainer_builder.js +43 -0
  162. package/dist/types.d.ts +8 -0
  163. package/dist/types.js +1 -0
  164. package/dist/utils/event_emitter.d.ts +40 -0
  165. package/dist/utils/event_emitter.js +57 -0
  166. package/dist/validation/index.d.ts +1 -0
  167. package/dist/validation/index.js +1 -0
  168. package/dist/validation/validator.d.ts +28 -0
  169. package/dist/validation/validator.js +132 -0
  170. package/dist/weights/aggregation.d.ts +21 -0
  171. package/dist/weights/aggregation.js +44 -0
  172. package/dist/weights/index.d.ts +2 -0
  173. package/dist/weights/index.js +2 -0
  174. package/dist/weights/weights_container.d.ts +68 -0
  175. package/dist/weights/weights_container.js +96 -0
  176. package/package.json +24 -15
  177. package/README.md +0 -53
  178. package/dist/core/async_buffer.d.ts +0 -41
  179. package/dist/core/async_buffer.js +0 -97
  180. package/dist/core/async_informant.d.ts +0 -20
  181. package/dist/core/async_informant.js +0 -69
  182. package/dist/core/client/base.d.ts +0 -33
  183. package/dist/core/client/base.js +0 -35
  184. package/dist/core/client/decentralized/base.d.ts +0 -32
  185. package/dist/core/client/decentralized/base.js +0 -212
  186. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  187. package/dist/core/client/decentralized/clear_text.js +0 -96
  188. package/dist/core/client/decentralized/index.d.ts +0 -4
  189. package/dist/core/client/decentralized/index.js +0 -9
  190. package/dist/core/client/decentralized/messages.d.ts +0 -41
  191. package/dist/core/client/decentralized/messages.js +0 -54
  192. package/dist/core/client/decentralized/peer.d.ts +0 -26
  193. package/dist/core/client/decentralized/peer.js +0 -210
  194. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  195. package/dist/core/client/decentralized/peer_pool.js +0 -92
  196. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  197. package/dist/core/client/decentralized/sec_agg.js +0 -190
  198. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  199. package/dist/core/client/decentralized/secret_shares.js +0 -39
  200. package/dist/core/client/decentralized/types.d.ts +0 -2
  201. package/dist/core/client/decentralized/types.js +0 -7
  202. package/dist/core/client/event_connection.d.ts +0 -37
  203. package/dist/core/client/event_connection.js +0 -158
  204. package/dist/core/client/federated/client.d.ts +0 -37
  205. package/dist/core/client/federated/client.js +0 -273
  206. package/dist/core/client/federated/index.d.ts +0 -2
  207. package/dist/core/client/federated/index.js +0 -7
  208. package/dist/core/client/federated/messages.d.ts +0 -38
  209. package/dist/core/client/federated/messages.js +0 -25
  210. package/dist/core/client/index.d.ts +0 -5
  211. package/dist/core/client/index.js +0 -11
  212. package/dist/core/client/local.d.ts +0 -8
  213. package/dist/core/client/local.js +0 -36
  214. package/dist/core/client/messages.d.ts +0 -28
  215. package/dist/core/client/messages.js +0 -33
  216. package/dist/core/client/utils.d.ts +0 -2
  217. package/dist/core/client/utils.js +0 -19
  218. package/dist/core/dataset/data/data.d.ts +0 -11
  219. package/dist/core/dataset/data/data.js +0 -20
  220. package/dist/core/dataset/data/data_split.d.ts +0 -5
  221. package/dist/core/dataset/data/data_split.js +0 -2
  222. package/dist/core/dataset/data/image_data.d.ts +0 -8
  223. package/dist/core/dataset/data/image_data.js +0 -64
  224. package/dist/core/dataset/data/index.d.ts +0 -5
  225. package/dist/core/dataset/data/index.js +0 -11
  226. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  227. package/dist/core/dataset/data/preprocessing.js +0 -33
  228. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  229. package/dist/core/dataset/data/tabular_data.js +0 -40
  230. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  231. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  232. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  233. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  234. package/dist/core/dataset/data_loader/index.js +0 -9
  235. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  236. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  237. package/dist/core/dataset/dataset.d.ts +0 -2
  238. package/dist/core/dataset/dataset.js +0 -2
  239. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  240. package/dist/core/dataset/dataset_builder.js +0 -96
  241. package/dist/core/dataset/index.d.ts +0 -4
  242. package/dist/core/dataset/index.js +0 -14
  243. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  244. package/dist/core/default_tasks/geotags.d.ts +0 -2
  245. package/dist/core/default_tasks/geotags.js +0 -69
  246. package/dist/core/default_tasks/index.d.ts +0 -6
  247. package/dist/core/default_tasks/index.js +0 -15
  248. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  249. package/dist/core/default_tasks/lus_covid.js +0 -96
  250. package/dist/core/default_tasks/mnist.d.ts +0 -2
  251. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  252. package/dist/core/default_tasks/titanic.d.ts +0 -2
  253. package/dist/core/index.d.ts +0 -18
  254. package/dist/core/index.js +0 -39
  255. package/dist/core/informant/graph_informant.js +0 -23
  256. package/dist/core/informant/index.d.ts +0 -3
  257. package/dist/core/informant/index.js +0 -9
  258. package/dist/core/informant/training_informant/base.d.ts +0 -31
  259. package/dist/core/informant/training_informant/base.js +0 -83
  260. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  261. package/dist/core/informant/training_informant/decentralized.js +0 -22
  262. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  263. package/dist/core/informant/training_informant/federated.js +0 -32
  264. package/dist/core/informant/training_informant/index.d.ts +0 -4
  265. package/dist/core/informant/training_informant/index.js +0 -11
  266. package/dist/core/informant/training_informant/local.d.ts +0 -6
  267. package/dist/core/informant/training_informant/local.js +0 -20
  268. package/dist/core/logging/console_logger.js +0 -33
  269. package/dist/core/logging/index.d.ts +0 -3
  270. package/dist/core/logging/index.js +0 -9
  271. package/dist/core/logging/logger.js +0 -9
  272. package/dist/core/logging/trainer_logger.d.ts +0 -24
  273. package/dist/core/logging/trainer_logger.js +0 -59
  274. package/dist/core/memory/base.d.ts +0 -22
  275. package/dist/core/memory/base.js +0 -9
  276. package/dist/core/memory/empty.d.ts +0 -14
  277. package/dist/core/memory/empty.js +0 -75
  278. package/dist/core/memory/index.d.ts +0 -3
  279. package/dist/core/memory/index.js +0 -9
  280. package/dist/core/memory/model_type.d.ts +0 -4
  281. package/dist/core/memory/model_type.js +0 -9
  282. package/dist/core/serialization/index.d.ts +0 -2
  283. package/dist/core/serialization/index.js +0 -6
  284. package/dist/core/serialization/model.d.ts +0 -5
  285. package/dist/core/serialization/model.js +0 -55
  286. package/dist/core/serialization/weights.js +0 -64
  287. package/dist/core/task/data_example.js +0 -24
  288. package/dist/core/task/digest.js +0 -18
  289. package/dist/core/task/display_information.js +0 -49
  290. package/dist/core/task/index.d.ts +0 -6
  291. package/dist/core/task/index.js +0 -15
  292. package/dist/core/task/model_compile_data.d.ts +0 -6
  293. package/dist/core/task/model_compile_data.js +0 -22
  294. package/dist/core/task/summary.js +0 -19
  295. package/dist/core/task/task.js +0 -35
  296. package/dist/core/task/task_handler.d.ts +0 -5
  297. package/dist/core/task/task_handler.js +0 -53
  298. package/dist/core/task/task_provider.d.ts +0 -6
  299. package/dist/core/task/task_provider.js +0 -13
  300. package/dist/core/task/training_information.js +0 -66
  301. package/dist/core/training/disco.d.ts +0 -23
  302. package/dist/core/training/disco.js +0 -130
  303. package/dist/core/training/index.d.ts +0 -2
  304. package/dist/core/training/index.js +0 -7
  305. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  306. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  307. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  308. package/dist/core/training/trainer/local_trainer.js +0 -34
  309. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  310. package/dist/core/training/trainer/round_tracker.js +0 -47
  311. package/dist/core/training/trainer/trainer.d.ts +0 -65
  312. package/dist/core/training/trainer/trainer.js +0 -160
  313. package/dist/core/training/trainer/trainer_builder.js +0 -95
  314. package/dist/core/training/training_schemes.d.ts +0 -5
  315. package/dist/core/training/training_schemes.js +0 -10
  316. package/dist/core/types.d.ts +0 -4
  317. package/dist/core/types.js +0 -2
  318. package/dist/core/validation/index.d.ts +0 -1
  319. package/dist/core/validation/index.js +0 -5
  320. package/dist/core/validation/validator.d.ts +0 -17
  321. package/dist/core/validation/validator.js +0 -104
  322. package/dist/core/weights/aggregation.d.ts +0 -7
  323. package/dist/core/weights/aggregation.js +0 -72
  324. package/dist/core/weights/index.d.ts +0 -2
  325. package/dist/core/weights/index.js +0 -7
  326. package/dist/core/weights/weights_container.d.ts +0 -19
  327. package/dist/core/weights/weights_container.js +0 -64
  328. package/dist/imports.d.ts +0 -2
  329. package/dist/imports.js +0 -7
  330. package/dist/memory/memory.d.ts +0 -26
  331. package/dist/memory/memory.js +0 -160
  332. package/dist/{core/task → task}/data_example.d.ts +1 -1
  333. package/dist/{core/task → task}/digest.d.ts +0 -0
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -1,26 +1,101 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.WebImageLoader = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("../..");
6
- var WebImageLoader = /** @class */ (function (_super) {
7
- (0, tslib_1.__extends)(WebImageLoader, _super);
8
- function WebImageLoader() {
9
- return _super !== null && _super.apply(this, arguments) || this;
1
+ import { Range } from 'immutable';
2
+ import * as tf from '@tensorflow/tfjs';
3
+ import { ImageData } from '../data/index.js';
4
+ import { DataLoader } from '../data_loader/index.js';
5
+ /**
6
+ * Image data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
7
+ * @epfml/discojs-web and @epfml/discojs-node.
8
+ * Load labels and correctly match them with their respective images, with the following constraints:
9
+ * 1. Images are given as 1 image/1 file;
10
+ * 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels.
11
+ */
12
+ export class ImageLoader extends DataLoader {
13
+ task;
14
+ constructor(task) {
15
+ super();
16
+ this.task = task;
10
17
  }
11
- WebImageLoader.prototype.readImageFrom = function (source) {
12
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
13
- var _a, _b;
14
- return (0, tslib_1.__generator)(this, function (_c) {
15
- switch (_c.label) {
16
- case 0:
17
- _b = (_a = __1.tf.browser).fromPixels;
18
- return [4 /*yield*/, createImageBitmap(source)];
19
- case 1: return [2 /*return*/, _b.apply(_a, [_c.sent()])];
18
+ async load(image, config) {
19
+ let tensorContainer;
20
+ if (config?.labels === undefined) {
21
+ tensorContainer = await this.readImageFrom(image, config?.channels);
22
+ }
23
+ else {
24
+ tensorContainer = {
25
+ xs: await this.readImageFrom(image, config?.channels),
26
+ ys: config.labels[0]
27
+ };
28
+ }
29
+ return tf.data.array([tensorContainer]);
30
+ }
31
+ async buildDataset(images, labels, indices, config) {
32
+ // Can't use arrow function for generator and need access to 'this'
33
+ // eslint-disable-next-line
34
+ const self = this;
35
+ async function* dataGenerator() {
36
+ const withLabels = config?.labels !== undefined;
37
+ let index = 0;
38
+ while (index < indices.length) {
39
+ const sample = await self.readImageFrom(images[indices[index]], config?.channels);
40
+ const label = withLabels ? labels[indices[index]] : undefined;
41
+ const value = withLabels ? { xs: sample, ys: label } : sample;
42
+ index++;
43
+ yield value;
44
+ }
45
+ }
46
+ // @ts-expect-error: For some reasons typescript refuses async generator but tensorflow do work with them
47
+ const dataset = tf.data.generator(dataGenerator);
48
+ return await ImageData.init(dataset, this.task, indices.length);
49
+ }
50
+ async loadAll(images, config) {
51
+ let labels = [];
52
+ const indices = Range(0, images.length).toArray();
53
+ if (config?.labels !== undefined) {
54
+ const labelList = this.task.trainingInformation?.LABEL_LIST;
55
+ if (labelList === undefined || !Array.isArray(labelList)) {
56
+ throw new Error('LABEL_LIST should be specified in the task training information');
57
+ }
58
+ const numberOfClasses = labelList.length;
59
+ // Map label strings to integer
60
+ const label_to_int = new Map(labelList.map((label_name, idx) => [label_name, idx]));
61
+ if (label_to_int.size != numberOfClasses) {
62
+ throw new Error("Input labels aren't matching the task LABEL_LIST");
63
+ }
64
+ labels = config.labels.map(label_name => {
65
+ const label_int = label_to_int.get(label_name);
66
+ if (label_int === undefined) {
67
+ throw new Error(`Found input label ${label_name} not specified in task LABEL_LIST`);
20
68
  }
69
+ return label_int;
21
70
  });
22
- });
23
- };
24
- return WebImageLoader;
25
- }(__1.data.ImageLoader));
26
- exports.WebImageLoader = WebImageLoader;
71
+ labels = await tf.oneHot(tf.tensor1d(labels, 'int32'), numberOfClasses).array();
72
+ }
73
+ if (config?.shuffle === undefined || config?.shuffle) {
74
+ this.shuffle(indices);
75
+ }
76
+ if (config?.validationSplit === undefined || config?.validationSplit === 0) {
77
+ const dataset = await this.buildDataset(images, labels, indices, config);
78
+ return {
79
+ train: dataset,
80
+ validation: undefined
81
+ };
82
+ }
83
+ const trainSize = Math.floor(images.length * (1 - config.validationSplit));
84
+ const trainIndices = indices.slice(0, trainSize);
85
+ const valIndices = indices.slice(trainSize);
86
+ const trainDataset = await this.buildDataset(images, labels, trainIndices, config);
87
+ const valDataset = await this.buildDataset(images, labels, valIndices, config);
88
+ return {
89
+ train: trainDataset,
90
+ validation: valDataset
91
+ };
92
+ }
93
+ shuffle(array) {
94
+ for (let i = 0; i < array.length; i++) {
95
+ const j = Math.floor(Math.random() * i);
96
+ const swap = array[i];
97
+ array[i] = array[j];
98
+ array[j] = swap;
99
+ }
100
+ }
101
+ }
@@ -1,2 +1,5 @@
1
- export { WebImageLoader } from './image_loader';
2
- export { WebTabularLoader } from './tabular_loader';
1
+ export type { DataConfig } from './data_loader.js';
2
+ export { DataLoader } from './data_loader.js';
3
+ export { ImageLoader } from './image_loader.js';
4
+ export { TabularLoader } from './tabular_loader.js';
5
+ export { TextLoader } from './text_loader.js';
@@ -1,7 +1,4 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.WebTabularLoader = exports.WebImageLoader = void 0;
4
- var image_loader_1 = require("./image_loader");
5
- Object.defineProperty(exports, "WebImageLoader", { enumerable: true, get: function () { return image_loader_1.WebImageLoader; } });
6
- var tabular_loader_1 = require("./tabular_loader");
7
- Object.defineProperty(exports, "WebTabularLoader", { enumerable: true, get: function () { return tabular_loader_1.WebTabularLoader; } });
1
+ export { DataLoader } from './data_loader.js';
2
+ export { ImageLoader } from './image_loader.js';
3
+ export { TabularLoader } from './tabular_loader.js';
4
+ export { TextLoader } from './text_loader.js';
@@ -1,4 +1,35 @@
1
- import { tf, data } from '../..';
2
- export declare class WebTabularLoader extends data.TabularLoader<File> {
3
- loadTabularDatasetFrom(source: File, csvConfig: Record<string, unknown>): tf.data.CSVDataset;
1
+ import type { Task } from '../../index.js';
2
+ import type { Dataset, DataSplit } from '../index.js';
3
+ import type { DataConfig } from './index.js';
4
+ import { DataLoader } from './index.js';
5
+ /**
6
+ * Tabular data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
7
+ * @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and consist of
8
+ * character-separated features and label(s). Such files typically have the .csv extension.
9
+ */
10
+ export declare abstract class TabularLoader<Source> extends DataLoader<Source> {
11
+ private readonly task;
12
+ readonly delimiter: string;
13
+ constructor(task: Task, delimiter?: string);
14
+ /**
15
+ * Creates a CSV dataset object based off the given source.
16
+ * @param source File object, URL string or local file system path.
17
+ * @param csvConfig Object expected by TF.js to create a CSVDataset.
18
+ * @returns The CSVDataset object built upon the given source.
19
+ */
20
+ abstract loadDatasetFrom(source: Source, csvConfig: Record<string, unknown>): Promise<Dataset>;
21
+ /**
22
+ * Expects delimiter-separated tabular data made of N columns. The data may be
23
+ * potentially split among several sources. Every source should contain N-1
24
+ * feature columns and 1 single label column.
25
+ * @param source List of File objects, URLs or file system paths.
26
+ * @param config
27
+ * @returns A TF.js dataset built upon read tabular data stored in the given sources.
28
+ */
29
+ load(source: Source, config?: DataConfig): Promise<Dataset>;
30
+ /**
31
+ * Creates the CSV datasets based off the given sources, then fuses them into a single CSV
32
+ * dataset.
33
+ */
34
+ loadAll(sources: Source[], config: DataConfig): Promise<DataSplit>;
4
35
  }
@@ -1,16 +1,76 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.WebTabularLoader = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("../..");
6
- var WebTabularLoader = /** @class */ (function (_super) {
7
- (0, tslib_1.__extends)(WebTabularLoader, _super);
8
- function WebTabularLoader() {
9
- return _super !== null && _super.apply(this, arguments) || this;
1
+ import { List, Map, Set } from 'immutable';
2
+ import { TabularData } from '../index.js';
3
+ import { DataLoader } from './index.js';
4
+ // Window size from which the dataset shuffling will sample
5
+ const BUFFER_SIZE = 1000;
6
+ /**
7
+ * Tabular data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
8
+ * @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and consist of
9
+ * character-separated features and label(s). Such files typically have the .csv extension.
10
+ */
11
+ export class TabularLoader extends DataLoader {
12
+ task;
13
+ delimiter;
14
+ constructor(task, delimiter = ',') {
15
+ super();
16
+ this.task = task;
17
+ this.delimiter = delimiter;
10
18
  }
11
- WebTabularLoader.prototype.loadTabularDatasetFrom = function (source, csvConfig) {
12
- return new __1.tf.data.CSVDataset(new __1.tf.data.FileDataSource(source), csvConfig);
13
- };
14
- return WebTabularLoader;
15
- }(__1.data.TabularLoader));
16
- exports.WebTabularLoader = WebTabularLoader;
19
+ /**
20
+ * Expects delimiter-separated tabular data made of N columns. The data may be
21
+ * potentially split among several sources. Every source should contain N-1
22
+ * feature columns and 1 single label column.
23
+ * @param source List of File objects, URLs or file system paths.
24
+ * @param config
25
+ * @returns A TF.js dataset built upon read tabular data stored in the given sources.
26
+ */
27
+ async load(source, config) {
28
+ /**
29
+ * Prepare the CSV config object based off the given features and labels.
30
+ * If labels is empty, then the returned dataset is comprised of samples only.
31
+ * Otherwise, each entry is of the form `{ xs, ys }` with `xs` as features and `ys`
32
+ * as labels.
33
+ */
34
+ if (config?.features === undefined) {
35
+ // TODO @s314cy
36
+ throw new Error('Not implemented');
37
+ }
38
+ const columnConfigs = Map(Set(config.features).map((feature) => [feature, { required: false, isLabel: false }])).merge(Set(config.labels).map((label) => [label, { required: true, isLabel: true }]));
39
+ const csvConfig = {
40
+ hasHeader: true,
41
+ columnConfigs: columnConfigs.toObject(),
42
+ configuredColumnsOnly: true,
43
+ delimiter: this.delimiter
44
+ };
45
+ const dataset = (await this.loadDatasetFrom(source, csvConfig)).map((t) => {
46
+ if (typeof t === 'object') {
47
+ if (('xs' in t) && ('ys' in t)) {
48
+ const { xs, ys } = t;
49
+ return {
50
+ xs: Object.values(xs),
51
+ ys: Object.values(ys)
52
+ };
53
+ }
54
+ else {
55
+ return t;
56
+ }
57
+ }
58
+ throw new TypeError('Expected TensorContainerObject');
59
+ });
60
+ return (config?.shuffle === undefined || config?.shuffle) ? dataset.shuffle(BUFFER_SIZE) : dataset;
61
+ }
62
+ /**
63
+ * Creates the CSV datasets based off the given sources, then fuses them into a single CSV
64
+ * dataset.
65
+ */
66
+ async loadAll(sources, config) {
67
+ const datasets = await Promise.all(sources.map(async (source) => await this.load(source, { ...config, shuffle: false })));
68
+ let dataset = List(datasets).reduce((acc, dataset) => acc.concatenate(dataset));
69
+ dataset = config?.shuffle === true ? dataset.shuffle(BUFFER_SIZE) : dataset;
70
+ const data = await TabularData.init(dataset, this.task);
71
+ // TODO: Implement validation split for tabular data (tricky due to streaming)
72
+ return {
73
+ train: data
74
+ };
75
+ }
76
+ }
@@ -0,0 +1,14 @@
1
+ import type { Task } from '../../index.js';
2
+ import type { DataSplit, Dataset } from '../index.js';
3
+ import { DataLoader, DataConfig } from './index.js';
4
+ /**
5
+ * Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
6
+ * @epfml/discojs-web and @epfml/discojs-node.
7
+ */
8
+ export declare abstract class TextLoader<S> extends DataLoader<S> {
9
+ private readonly task;
10
+ constructor(task: Task);
11
+ abstract loadDatasetFrom(source: S): Promise<Dataset>;
12
+ load(source: S, config?: DataConfig): Promise<Dataset>;
13
+ loadAll(sources: S[], config?: DataConfig): Promise<DataSplit>;
14
+ }
@@ -0,0 +1,25 @@
1
+ import { TextData } from '../index.js';
2
+ import { DataLoader } from './index.js';
3
+ /**
4
+ * Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
5
+ * @epfml/discojs-web and @epfml/discojs-node.
6
+ */
7
export class TextLoader extends DataLoader {
    /**
     * Task received by the constructor; `loadAll` hands it to `TextData.init`.
     */
    task;
    /**
     * @param task The task the loaded text data is initialised with
     */
    constructor(task) {
        super();
        this.task = task;
    }
    /**
     * Loads a single source through `loadDatasetFrom`, which this class does not define
     * itself and expects a subclass to provide.
     *
     * @param source The source to read
     * @param config Optional loading configuration; shuffling is on unless `shuffle` is
     * set to a falsy value other than `undefined`
     * @returns The loaded dataset, shuffled or not depending on `config.shuffle`
     */
    async load(source, config) {
        const dataset = await this.loadDatasetFrom(source);
        if (config?.shuffle === undefined || config?.shuffle) {
            // 1st arg: Stream shuffling buffer size
            return dataset.shuffle(1000, undefined, true);
        }
        return dataset;
    }
    /**
     * Loads every source with `load`, concatenates the datasets in the order the sources
     * were given and wraps the result as the `train` entry of a data split.
     *
     * @param sources The sources to read; at least one is required
     * @param config Optional loading configuration, passed unchanged to each `load` call
     * @returns A data split whose only entry is `train`
     * @throws {TypeError} If `sources` is empty
     */
    async loadAll(sources, config) {
        // A seedless reduce over an empty array already throws a TypeError; fail up front
        // with the same error type but a message that names the actual problem.
        if (sources.length === 0) {
            throw new TypeError('Cannot load text data: at least one source is required');
        }
        const datasets = await Promise.all(sources.map((source) => this.load(source, config)));
        const concatenated = datasets.reduce((acc, dataset) => acc.concatenate(dataset));
        return {
            train: await TextData.init(concatenated, this.task)
        };
    }
}
@@ -0,0 +1,5 @@
1
+ import type tf from '@tensorflow/tfjs';
2
+ /**
3
+ * Convenient type for the common dataset type used in TF.js.
4
+ */
5
+ export type Dataset = tf.data.Dataset<tf.TensorContainer>;
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,60 @@
1
+ import type { Task } from '../index.js';
2
+ import type { DataSplit } from './data/index.js';
3
+ import type { DataConfig, DataLoader } from './data_loader/data_loader.js';
4
+ /**
5
+ * Incrementally builds a dataset from the provided file sources. The sources may
6
+ * either be file blobs or regular file system paths.
7
+ */
8
export declare class DatasetBuilder<Source> {
    /**
     * Loader whose `loadAll` turns the buffered file sources into data when `build` runs.
     */
    private readonly dataLoader;
    /**
     * Task the resulting dataset is built for.
     */
    readonly task: Task;
    /**
     * Sources that were added without a label.
     */
    private _sources;
    /**
     * Sources that were added with a label, grouped per label.
     */
    private labelledSources;
    /**
     * Set once `build` has completed; cleared again when sources are added or removed.
     */
    private _built;
    /**
     * @param dataLoader The data loader used to load the data contained in the provided files
     * @param task The task for which the dataset should be built
     */
    constructor(dataLoader: DataLoader<Source>, task: Task);
    /**
     * Appends file sources to the builder's buffer, optionally under a label for
     * supervised learning.
     * @param sources The array of file sources
     * @param label The file sources label
     */
    addFiles(sources: Source[], label?: string): void;
    /**
     * Empties the file sources buffers. With a label, only the sources stored under
     * that label are dropped; without one, only the unlabelled sources are dropped.
     * @param label The file sources label
     */
    clearFiles(label?: string): void;
    /**
     * Marks the builder as not built.
     */
    private resetBuiltState;
    /**
     * Produces one label per labelled source, i.e. each label repeated once for every
     * source stored under it.
     */
    private getLabels;
    /**
     * Loads all buffered sources through the data loader and marks the builder as built.
     * Rejects when no sources were added, or when both labelled and unlabelled sources are present.
     * @param config Optional loading configuration; its entries override the builder's defaults
     */
    build(config?: DataConfig): Promise<DataSplit>;
    /**
     * Whether the dataset builder has already been consumed to produce a dataset.
     */
    get built(): boolean;
    /**
     * The larger of the number of unlabelled sources and the number of distinct labels.
     */
    get size(): number;
    /**
     * The unlabelled sources when there are any, otherwise every labelled source in a single array.
     */
    get sources(): Source[];
}
@@ -0,0 +1,142 @@
1
+ import { Map } from 'immutable';
2
+ /**
3
+ * Incrementally builds a dataset from the provided file sources. The sources may
4
+ * either be file blobs or regular file system paths.
5
+ */
6
export class DatasetBuilder {
    /**
     * Loader whose `loadAll` turns the buffered file sources into data when `build` runs.
     */
    dataLoader;
    /**
     * Task the resulting dataset is built for.
     */
    task;
    /**
     * The buffer of unlabelled file sources.
     */
    _sources;
    /**
     * The buffer of labelled file sources, keyed by label.
     */
    labelledSources;
    /**
     * Whether a dataset was already produced.
     */
    // TODO useless, responsibility on callers
    _built;
    /**
     * @param dataLoader The data loader used to load the data contained in the provided files
     * @param task The task for which the dataset should be built
     */
    constructor(dataLoader, task) {
        this.dataLoader = dataLoader;
        this.task = task;
        this._sources = [];
        this.labelledSources = Map();
        this._built = false;
    }
    /**
     * Adds the given file sources to the builder's buffer. Sources may be provided a label in the case
     * of supervised learning.
     * @param sources The array of file sources
     * @param label The file sources label
     */
    addFiles(sources, label) {
        if (this.built) {
            this.resetBuiltState();
        }
        if (label === undefined) {
            this._sources = this._sources.concat(sources);
            return;
        }
        // Always concatenate onto an array owned by the builder, so that the caller's
        // array is never stored by reference and later mutations of it cannot leak in.
        const currentSources = this.labelledSources.get(label) ?? [];
        this.labelledSources = this.labelledSources.set(label, currentSources.concat(sources));
    }
    /**
     * Clears the file sources buffers. If a label is provided, only the file sources
     * corresponding to the given label will be removed; otherwise only the unlabelled
     * sources are removed.
     * @param label The file sources label
     */
    clearFiles(label) {
        if (this.built) {
            this.resetBuiltState();
        }
        if (label === undefined) {
            this._sources = [];
        }
        else {
            this.labelledSources = this.labelledSources.delete(label);
        }
    }
    // If files are added or removed, then this should be called since the latest
    // version of the dataset_builder has not yet been built.
    resetBuiltState() {
        this._built = false;
    }
    /**
     * Produces one label per labelled source.
     * @returns The labels, in the same iteration order as the labelled sources
     */
    getLabels() {
        // We need to duplicate the labels as we need one for each source.
        // Say for label A we have sources [img1, img2, img3], then we
        // need labels [A, A, A].
        const labels = [];
        this.labelledSources.forEach((sources, label) => {
            for (let i = 0; i < sources.length; i++) {
                labels.push(label);
            }
        });
        return labels;
    }
    /**
     * Loads all buffered sources through the data loader and marks the builder as built.
     *
     * @param config Optional loading configuration; its entries override the defaults chosen here
     * @returns Whatever the data loader's `loadAll` resolves to
     * @throws {Error} If no sources were added, or if both labelled and unlabelled sources are present
     */
    async build(config) {
        // Require that at least one source collection is non-empty, but not both
        const hasUnlabelled = this._sources.length > 0;
        const hasLabelled = this.labelledSources.size > 0;
        if (!hasUnlabelled && !hasLabelled) {
            throw new Error('Please provide dataset input files');
        }
        if (hasUnlabelled && hasLabelled) {
            throw new Error('Please provide either labelled or unlabelled dataset input files, not both');
        }
        let dataTuple;
        if (hasUnlabelled) {
            const inference = config?.inference === true;
            const defaultConfig = {
                features: this.task.trainingInformation.inputColumns,
                // Inferring model, no labels needed; otherwise labels are contained in the given sources
                ...(inference ? {} : { labels: this.task.trainingInformation.outputColumns }),
                shuffle: false
            };
            dataTuple = await this.dataLoader.loadAll(this._sources, { ...defaultConfig, ...config });
        }
        else {
            // Labels are inferred from the file selection boxes
            const defaultConfig = {
                labels: this.getLabels(),
                shuffle: false
            };
            const sources = this.labelledSources.valueSeq().toArray().flat();
            dataTuple = await this.dataLoader.loadAll(sources, { ...defaultConfig, ...config });
        }
        // TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
        this._built = true;
        return dataTuple;
    }
    /**
     * Whether the dataset builder has already been consumed to produce a dataset.
     */
    get built() {
        return this._built;
    }
    /**
     * The larger of the number of unlabelled sources and the number of distinct labels.
     */
    // NOTE(review): for labelled sources this counts labels, not files — confirm that is intended.
    get size() {
        return Math.max(this._sources.length, this.labelledSources.size);
    }
    /**
     * The unlabelled sources when there are any, otherwise every labelled source in a single array.
     */
    get sources() {
        return this._sources.length > 0 ? this._sources : this.labelledSources.valueSeq().toArray().flat();
    }
}
@@ -0,0 +1,5 @@
1
+ export type { Dataset } from './dataset.js';
2
+ export { DatasetBuilder } from './dataset_builder.js';
3
+ export { ImageLoader, TabularLoader, DataLoader, TextLoader } from './data_loader/index.js';
4
+ export type { DataSplit } from './data/index.js';
5
+ export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './data/index.js';
@@ -0,0 +1,3 @@
1
+ export { DatasetBuilder } from './dataset_builder.js';
2
+ export { ImageLoader, TabularLoader, DataLoader, TextLoader } from './data_loader/index.js';
3
+ export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './data/index.js';
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../../index.js';
2
+ export declare const cifar10: TaskProvider;
@@ -1,12 +1,10 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.cifar10 = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- exports.cifar10 = {
7
- getTask: function () {
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../../index.js';
3
+ import baseModel from './model.js';
4
+ export const cifar10 = {
5
+ getTask() {
8
6
  return {
9
- taskID: 'cifar10',
7
+ id: 'cifar10',
10
8
  displayInformation: {
11
9
  taskTitle: 'CIFAR10',
12
10
  summary: {
@@ -25,17 +23,12 @@ exports.cifar10 = {
25
23
  roundDuration: 10,
26
24
  validationSplit: 0.2,
27
25
  batchSize: 10,
28
- modelCompileData: {
29
- optimizer: 'sgd',
30
- loss: 'categoricalCrossentropy',
31
- metrics: ['accuracy']
32
- },
33
26
  dataType: 'image',
34
- IMAGE_H: 32,
35
- IMAGE_W: 32,
36
- preprocessingFunctions: [],
27
+ preprocessingFunctions: [data.ImagePreprocessing.Resize],
28
+ IMAGE_H: 224,
29
+ IMAGE_W: 224,
37
30
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
38
- scheme: 'Decentralized',
31
+ scheme: 'decentralized',
39
32
  noiseScale: undefined,
40
33
  clippingRadius: 20,
41
34
  decentralizedSecure: true,
@@ -44,25 +37,24 @@ exports.cifar10 = {
44
37
  }
45
38
  };
46
39
  },
47
- getModel: function () {
48
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
49
- var mobilenet, x, predictions;
50
- return (0, tslib_1.__generator)(this, function (_a) {
51
- switch (_a.label) {
52
- case 0: return [4 /*yield*/, __1.tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json')];
53
- case 1:
54
- mobilenet = _a.sent();
55
- x = mobilenet.getLayer('global_average_pooling2d_1');
56
- predictions = __1.tf.layers
57
- .dense({ units: 10, activation: 'softmax', name: 'denseModified' })
58
- .apply(x.output);
59
- return [2 /*return*/, __1.tf.model({
60
- inputs: mobilenet.input,
61
- outputs: predictions,
62
- name: 'modelModified'
63
- })];
64
- }
65
- });
40
+ async getModel() {
41
+ const mobilenet = await tf.loadLayersModel({
42
+ load: async () => Promise.resolve(baseModel),
43
+ });
44
+ const x = mobilenet.getLayer('global_average_pooling2d_1');
45
+ const predictions = tf.layers
46
+ .dense({ units: 10, activation: 'softmax', name: 'denseModified' })
47
+ .apply(x.output);
48
+ const model = tf.model({
49
+ inputs: mobilenet.input,
50
+ outputs: predictions,
51
+ name: 'modelModified'
52
+ });
53
+ model.compile({
54
+ optimizer: 'sgd',
55
+ loss: 'categoricalCrossentropy',
56
+ metrics: ['accuracy']
66
57
  });
58
+ return new models.TFJS(model);
67
59
  }
68
60
  };