ins-pricing 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +74 -56
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +52 -50
- ins_pricing/cli/BayesOpt_incremental.py +832 -898
- ins_pricing/cli/Explain_Run.py +31 -23
- ins_pricing/cli/Explain_entry.py +532 -579
- ins_pricing/cli/Pricing_Run.py +31 -23
- ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +375 -375
- ins_pricing/cli/utils/import_resolver.py +382 -365
- ins_pricing/cli/utils/notebook_utils.py +340 -340
- ins_pricing/cli/watchdog_run.py +209 -201
- ins_pricing/frontend/README.md +573 -419
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/config_builder.py +1 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +67 -0
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/bayesopt/README.md +59 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +116 -83
- ins_pricing/utils/device.py +255 -255
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
- ins_pricing-0.5.0.dist-info/RECORD +131 -0
- ins_pricing/CHANGELOG.md +0 -272
- ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
- ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
- ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
- ins_pricing/docs/modelling/README.md +0 -34
- ins_pricing/frontend/QUICKSTART.md +0 -152
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.4.dist-info/RECORD +0 -137
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
- {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,785 +1,788 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import hashlib
|
|
4
|
-
import os
|
|
5
|
-
import time
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
from typing import Any, Dict, Optional, Tuple
|
|
8
|
-
|
|
9
|
-
import numpy as np
|
|
10
|
-
import pandas as pd
|
|
11
|
-
import torch
|
|
12
|
-
import torch.distributed as dist
|
|
13
|
-
import torch.nn as nn
|
|
14
|
-
from sklearn.neighbors import NearestNeighbors
|
|
15
|
-
from torch.cuda.amp import autocast, GradScaler
|
|
16
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
17
|
-
from torch.nn.utils import clip_grad_norm_
|
|
18
|
-
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
pynndescent
|
|
41
|
-
_PYNN_AVAILABLE =
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
h =
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
self.
|
|
124
|
-
self.
|
|
125
|
-
self.
|
|
126
|
-
self.
|
|
127
|
-
self.
|
|
128
|
-
self.
|
|
129
|
-
self.
|
|
130
|
-
self.
|
|
131
|
-
self.
|
|
132
|
-
self.
|
|
133
|
-
self.
|
|
134
|
-
|
|
135
|
-
self.
|
|
136
|
-
self.
|
|
137
|
-
|
|
138
|
-
self.
|
|
139
|
-
self.
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
self.
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
self.
|
|
146
|
-
self.
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
if
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
self.
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
self.
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
self.device = torch.device(
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
self.
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
self.
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
self.
|
|
274
|
-
|
|
275
|
-
self.
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
hasher.
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
"
|
|
306
|
-
"
|
|
307
|
-
"
|
|
308
|
-
|
|
309
|
-
),
|
|
310
|
-
"
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
"
|
|
319
|
-
"
|
|
320
|
-
"
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
X_df
|
|
329
|
-
|
|
330
|
-
X_df.
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
cols =
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
adj =
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
if
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
self._adj_cache_tensor =
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
self._adj_cache_tensor =
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
self._adj_cache_meta =
|
|
598
|
-
self._adj_cache_key =
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
X_val_tensor
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
self.
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
scaler.
|
|
662
|
-
scaler.
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
def
|
|
731
|
-
self.
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
return
|
|
751
|
-
|
|
752
|
-
def
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import torch
|
|
12
|
+
import torch.distributed as dist
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
from sklearn.neighbors import NearestNeighbors
|
|
15
|
+
from torch.cuda.amp import autocast, GradScaler
|
|
16
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
17
|
+
from torch.nn.utils import clip_grad_norm_
|
|
18
|
+
|
|
19
|
+
from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
|
|
20
|
+
from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
|
|
21
|
+
from ins_pricing.utils import EPS
|
|
22
|
+
from ins_pricing.utils.io import IOUtils
|
|
23
|
+
from ins_pricing.utils.losses import (
|
|
24
|
+
infer_loss_name_from_model_name,
|
|
25
|
+
normalize_loss_name,
|
|
26
|
+
resolve_tweedie_power,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
from torch_geometric.nn import knn_graph
|
|
31
|
+
from torch_geometric.utils import add_self_loops, to_undirected
|
|
32
|
+
_PYG_AVAILABLE = True
|
|
33
|
+
except Exception:
|
|
34
|
+
knn_graph = None # type: ignore
|
|
35
|
+
add_self_loops = None # type: ignore
|
|
36
|
+
to_undirected = None # type: ignore
|
|
37
|
+
_PYG_AVAILABLE = False
|
|
38
|
+
|
|
39
|
+
try:
|
|
40
|
+
import pynndescent
|
|
41
|
+
_PYNN_AVAILABLE = True
|
|
42
|
+
except Exception:
|
|
43
|
+
pynndescent = None # type: ignore
|
|
44
|
+
_PYNN_AVAILABLE = False
|
|
45
|
+
|
|
46
|
+
_GNN_MPS_WARNED = False
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# =============================================================================
|
|
50
|
+
# Simplified GNN implementation.
|
|
51
|
+
# =============================================================================
|
|
52
|
+
|
|
53
|
+
def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
|
54
|
+
"""Matrix multiply that supports sparse or dense adjacency."""
|
|
55
|
+
if adj.is_sparse:
|
|
56
|
+
return torch.sparse.mm(adj, x)
|
|
57
|
+
return adj.matmul(x)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SimpleGraphLayer(nn.Module):
|
|
61
|
+
def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
|
|
62
|
+
super().__init__()
|
|
63
|
+
self.linear = nn.Linear(in_dim, out_dim)
|
|
64
|
+
self.activation = nn.ReLU(inplace=True)
|
|
65
|
+
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
|
66
|
+
|
|
67
|
+
def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
|
|
68
|
+
# Message passing with normalized sparse adjacency: A_hat * X * W.
|
|
69
|
+
h = _adj_mm(adj, x)
|
|
70
|
+
h = self.linear(h)
|
|
71
|
+
h = self.activation(h)
|
|
72
|
+
return self.dropout(h)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SimpleGNN(nn.Module):
|
|
76
|
+
def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
|
|
77
|
+
dropout: float = 0.1, task_type: str = 'regression'):
|
|
78
|
+
super().__init__()
|
|
79
|
+
layers = []
|
|
80
|
+
dim_in = input_dim
|
|
81
|
+
for _ in range(max(1, num_layers)):
|
|
82
|
+
layers.append(SimpleGraphLayer(
|
|
83
|
+
dim_in, hidden_dim, dropout=dropout))
|
|
84
|
+
dim_in = hidden_dim
|
|
85
|
+
self.layers = nn.ModuleList(layers)
|
|
86
|
+
self.output = nn.Linear(hidden_dim, 1)
|
|
87
|
+
if task_type == 'classification':
|
|
88
|
+
self.output_act = nn.Identity()
|
|
89
|
+
else:
|
|
90
|
+
self.output_act = nn.Softplus()
|
|
91
|
+
self.task_type = task_type
|
|
92
|
+
# Keep adjacency as a buffer for DataParallel copies.
|
|
93
|
+
self.register_buffer("adj_buffer", torch.empty(0))
|
|
94
|
+
|
|
95
|
+
def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
96
|
+
adj_used = adj if adj is not None else getattr(
|
|
97
|
+
self, "adj_buffer", None)
|
|
98
|
+
if adj_used is None or adj_used.numel() == 0:
|
|
99
|
+
raise RuntimeError("Adjacency is not set for GNN forward.")
|
|
100
|
+
h = x
|
|
101
|
+
for layer in self.layers:
|
|
102
|
+
h = layer(h, adj_used)
|
|
103
|
+
h = _adj_mm(adj_used, h)
|
|
104
|
+
out = self.output(h)
|
|
105
|
+
return self.output_act(out)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
|
|
109
|
+
def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
             num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
             learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
             task_type: str = 'regression', tweedie_power: float = 1.5,
             weight_decay: float = 0.0,
             use_data_parallel: bool = False, use_ddp: bool = False,
             use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
             graph_cache_path: Optional[str] = None,
             max_gpu_knn_nodes: Optional[int] = None,
             knn_gpu_mem_ratio: float = 0.9,
             knn_gpu_mem_overhead: float = 2.0,
             knn_cpu_jobs: Optional[int] = -1,
             loss_name: Optional[str] = None) -> None:
    """Store hyper-parameters, resolve the training loss, pick a device and build the network."""
    super().__init__()
    # --- plain hyper-parameter bookkeeping --------------------------------
    self.model_nme = model_nme
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers
    self.k_neighbors = max(1, k_neighbors)
    self.dropout = dropout
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.epochs = epochs
    self.patience = patience
    self.task_type = task_type
    self.use_approx_knn = use_approx_knn
    self.approx_knn_threshold = approx_knn_threshold
    self.graph_cache_path = Path(graph_cache_path) if graph_cache_path else None
    self.max_gpu_knn_nodes = max_gpu_knn_nodes
    # Clamp the GPU-memory knobs to sane ranges.
    self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
    self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
    self.knn_cpu_jobs = knn_cpu_jobs
    # Dense adjacency on MPS is only feasible up to this node count.
    self.mps_dense_max_nodes = int(
        os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
    )
    self._knn_warning_emitted = False
    self._mps_fallback_triggered = False
    self._adj_cache_meta: Optional[Dict[str, Any]] = None
    self._adj_cache_key: Optional[Tuple[Any, ...]] = None
    self._adj_cache_tensor: Optional[torch.Tensor] = None

    # --- loss resolution ---------------------------------------------------
    resolved_loss = normalize_loss_name(loss_name, self.task_type)
    if self.task_type == 'classification':
        self.loss_name = "logloss"
        self.tw_power = None
    else:
        if resolved_loss == "auto":
            resolved_loss = infer_loss_name_from_model_name(self.model_nme)
        self.loss_name = resolved_loss
        if self.loss_name == "tweedie":
            self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
        else:
            self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)

    # --- distributed setup / device selection ------------------------------
    self.ddp_enabled = False
    self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    self.data_parallel_enabled = False
    self._ddp_disabled = False

    if use_ddp:
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        if world_size > 1:
            print(
                "[GNN] DDP training is not supported; falling back to single process.",
                flush=True,
            )
            self._ddp_disabled = True
            use_ddp = False

    # DDP only works with CUDA; fall back to single process if init fails.
    if use_ddp and torch.cuda.is_available():
        ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
        if ddp_ok:
            self.ddp_enabled = True
            self.local_rank = local_rank
            self.device = torch.device(f'cuda:{local_rank}')
        else:
            self.device = torch.device('cuda')
    elif torch.cuda.is_available():
        if self._ddp_disabled:
            self.device = torch.device(f'cuda:{self.local_rank}')
        else:
            self.device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        self.device = torch.device('mps')
        global _GNN_MPS_WARNED
        if not _GNN_MPS_WARNED:
            print(
                "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
                flush=True,
            )
            _GNN_MPS_WARNED = True
    else:
        self.device = torch.device('cpu')
    self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)

    # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
    if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
        self.data_parallel_enabled = True
        self.gnn = nn.DataParallel(
            self.gnn, device_ids=list(range(torch.cuda.device_count())))
        self.device = torch.device('cuda')

    if self.ddp_enabled:
        self.gnn = DDP(
            self.gnn,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            find_unused_parameters=False
        )
|
|
228
|
+
|
|
229
|
+
@staticmethod
|
|
230
|
+
def _validate_vector(arr, name: str, n_rows: int) -> None:
|
|
231
|
+
if arr is None:
|
|
232
|
+
return
|
|
233
|
+
if isinstance(arr, pd.DataFrame):
|
|
234
|
+
if arr.shape[1] != 1:
|
|
235
|
+
raise ValueError(f"{name} must be 1d (single column).")
|
|
236
|
+
length = len(arr)
|
|
237
|
+
else:
|
|
238
|
+
arr_np = np.asarray(arr)
|
|
239
|
+
if arr_np.ndim == 0:
|
|
240
|
+
raise ValueError(f"{name} must be 1d.")
|
|
241
|
+
if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
|
|
242
|
+
raise ValueError(f"{name} must be 1d or Nx1.")
|
|
243
|
+
length = arr_np.shape[0]
|
|
244
|
+
if length != n_rows:
|
|
245
|
+
raise ValueError(
|
|
246
|
+
f"{name} length {length} does not match X length {n_rows}."
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
def _unwrap_gnn(self) -> nn.Module:
|
|
250
|
+
if isinstance(self.gnn, (DDP, nn.DataParallel)):
|
|
251
|
+
return self.gnn.module
|
|
252
|
+
return self.gnn
|
|
253
|
+
|
|
254
|
+
def _set_adj_buffer(self, adj: torch.Tensor) -> None:
|
|
255
|
+
base = self._unwrap_gnn()
|
|
256
|
+
if hasattr(base, "adj_buffer"):
|
|
257
|
+
base.adj_buffer = adj
|
|
258
|
+
else:
|
|
259
|
+
base.register_buffer("adj_buffer", adj)
|
|
260
|
+
|
|
261
|
+
@staticmethod
|
|
262
|
+
def _is_mps_unsupported_error(exc: BaseException) -> bool:
|
|
263
|
+
msg = str(exc).lower()
|
|
264
|
+
if "mps" not in msg:
|
|
265
|
+
return False
|
|
266
|
+
if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
|
|
267
|
+
return True
|
|
268
|
+
return "sparse" in msg
|
|
269
|
+
|
|
270
|
+
def _fallback_to_cpu(self, reason: str) -> None:
|
|
271
|
+
if self.device.type != "mps" or self._mps_fallback_triggered:
|
|
272
|
+
return
|
|
273
|
+
self._mps_fallback_triggered = True
|
|
274
|
+
print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
|
|
275
|
+
self.device = torch.device("cpu")
|
|
276
|
+
self.use_pyg_knn = False
|
|
277
|
+
self.data_parallel_enabled = False
|
|
278
|
+
self.ddp_enabled = False
|
|
279
|
+
base = self._unwrap_gnn()
|
|
280
|
+
try:
|
|
281
|
+
base = base.to(self.device)
|
|
282
|
+
except Exception:
|
|
283
|
+
pass
|
|
284
|
+
self.gnn = base
|
|
285
|
+
self.invalidate_graph_cache()
|
|
286
|
+
|
|
287
|
+
def _run_with_mps_fallback(self, fn, *args, **kwargs):
|
|
288
|
+
try:
|
|
289
|
+
return fn(*args, **kwargs)
|
|
290
|
+
except (RuntimeError, NotImplementedError) as exc:
|
|
291
|
+
if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
|
|
292
|
+
self._fallback_to_cpu(str(exc))
|
|
293
|
+
return fn(*args, **kwargs)
|
|
294
|
+
raise
|
|
295
|
+
|
|
296
|
+
def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
    """Fingerprint *X_df* plus every kNN setting that shapes the cached graph.

    Two calls produce equal metadata only when the data (values, index,
    columns) and the graph-construction configuration are both identical.
    """
    digest = hashlib.sha256()
    digest.update(pd.util.hash_pandas_object(X_df, index=False).values.tobytes())
    digest.update(pd.util.hash_pandas_object(X_df.index, index=False).values.tobytes())
    digest.update(",".join(map(str, X_df.columns)).encode("utf-8", errors="ignore"))
    knn_config = {
        "k_neighbors": int(self.k_neighbors),
        "use_approx_knn": bool(self.use_approx_knn),
        "approx_knn_threshold": int(self.approx_knn_threshold),
        "use_pyg_knn": bool(self.use_pyg_knn),
        "pynndescent_available": bool(_PYNN_AVAILABLE),
        "max_gpu_knn_nodes": (
            None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
        ),
        "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
        "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
    }
    return {
        "n_samples": int(X_df.shape[0]),
        "n_features": int(X_df.shape[1]),
        "hash": digest.hexdigest(),
        "knn_config": knn_config,
        # MPS cannot consume sparse adjacency, so its caches are dense.
        "adj_format": "dense" if self.device.type == "mps" else "sparse",
        "device_type": self.device.type,
    }
|
|
325
|
+
|
|
326
|
+
def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
|
|
327
|
+
return (
|
|
328
|
+
id(X_df),
|
|
329
|
+
id(getattr(X_df, "_mgr", None)),
|
|
330
|
+
id(X_df.index),
|
|
331
|
+
X_df.shape,
|
|
332
|
+
tuple(map(str, X_df.columns)),
|
|
333
|
+
X_df.attrs.get("graph_cache_key"),
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
def invalidate_graph_cache(self) -> None:
    """Drop any memoized adjacency so the next fit rebuilds the graph."""
    self._adj_cache_meta = None
    self._adj_cache_key = None
    self._adj_cache_tensor = None
|
|
340
|
+
|
|
341
|
+
def _load_cached_adj(self,
|
|
342
|
+
X_df: pd.DataFrame,
|
|
343
|
+
meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
|
|
344
|
+
if self.graph_cache_path and self.graph_cache_path.exists():
|
|
345
|
+
if meta_expected is None:
|
|
346
|
+
meta_expected = self._graph_cache_meta(X_df)
|
|
347
|
+
try:
|
|
348
|
+
payload = torch.load(self.graph_cache_path, map_location="cpu")
|
|
349
|
+
except Exception as exc:
|
|
350
|
+
print(
|
|
351
|
+
f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
|
|
352
|
+
return None
|
|
353
|
+
if isinstance(payload, dict) and "adj" in payload:
|
|
354
|
+
meta_cached = payload.get("meta")
|
|
355
|
+
if meta_cached == meta_expected:
|
|
356
|
+
adj = payload["adj"]
|
|
357
|
+
if self.device.type == "mps" and getattr(adj, "is_sparse", False):
|
|
358
|
+
print(
|
|
359
|
+
f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
|
|
360
|
+
)
|
|
361
|
+
return None
|
|
362
|
+
return adj.to(self.device)
|
|
363
|
+
print(
|
|
364
|
+
f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
|
|
365
|
+
return None
|
|
366
|
+
if isinstance(payload, torch.Tensor):
|
|
367
|
+
print(
|
|
368
|
+
f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
|
|
369
|
+
return None
|
|
370
|
+
print(
|
|
371
|
+
f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
|
|
372
|
+
return None
|
|
373
|
+
|
|
374
|
+
def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
    """Build a symmetric kNN edge index (with self-loops) on the CPU.

    Uses pynndescent when approximate search is enabled/appropriate and
    available, otherwise exact sklearn NearestNeighbors.
    """
    num_nodes = X_np.shape[0]
    k_eff = min(self.k_neighbors, max(1, num_nodes - 1))
    query_k = min(k_eff + 1, num_nodes)  # +1 because each point matches itself
    want_approx = (self.use_approx_knn
                   or num_nodes >= self.approx_knn_threshold) and _PYNN_AVAILABLE
    neighbor_ids = None
    if want_approx:
        try:
            approx_index = pynndescent.NNDescent(
                X_np,
                n_neighbors=query_k,
                random_state=0
            )
            neighbor_ids, _ = approx_index.neighbor_graph
        except Exception as exc:
            print(
                f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
            want_approx = False

    if neighbor_ids is None:
        exact = NearestNeighbors(
            n_neighbors=query_k,
            algorithm="auto",
            n_jobs=self.knn_cpu_jobs,
        )
        exact.fit(X_np)
        _, neighbor_ids = exact.kneighbors(X_np)

    neighbor_ids = np.asarray(neighbor_ids)
    src = np.repeat(np.arange(num_nodes), query_k).astype(np.int64, copy=False)
    dst = neighbor_ids.reshape(-1).astype(np.int64, copy=False)
    keep = src != dst  # drop self-matches returned by the search
    src = src[keep]
    dst = dst[keep]
    loops = np.arange(num_nodes, dtype=np.int64)
    # Symmetrize the directed kNN edges and re-add explicit self-loops.
    rows = np.concatenate([src, dst, loops])
    cols = np.concatenate([dst, src, loops])
    return torch.as_tensor(np.stack([rows, cols], axis=0), device=self.device)
|
|
419
|
+
|
|
420
|
+
def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
|
|
421
|
+
if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
|
|
422
|
+
# Defensive: check use_pyg_knn before calling.
|
|
423
|
+
raise RuntimeError(
|
|
424
|
+
"GPU graph builder requested but PyG is unavailable.")
|
|
425
|
+
|
|
426
|
+
n_samples = X_tensor.size(0)
|
|
427
|
+
k = min(self.k_neighbors, max(1, n_samples - 1))
|
|
428
|
+
|
|
429
|
+
# knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
|
|
430
|
+
edge_index = knn_graph(
|
|
431
|
+
X_tensor,
|
|
432
|
+
k=k,
|
|
433
|
+
loop=False
|
|
434
|
+
)
|
|
435
|
+
edge_index = to_undirected(edge_index, num_nodes=n_samples)
|
|
436
|
+
edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
|
|
437
|
+
return edge_index
|
|
438
|
+
|
|
439
|
+
def _log_knn_fallback(self, reason: str) -> None:
|
|
440
|
+
if self._knn_warning_emitted:
|
|
441
|
+
return
|
|
442
|
+
if (not self.ddp_enabled) or self.local_rank == 0:
|
|
443
|
+
print(f"[GNN] Falling back to CPU kNN builder: {reason}")
|
|
444
|
+
self._knn_warning_emitted = True
|
|
445
|
+
|
|
446
|
+
def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
|
|
447
|
+
if not self.use_pyg_knn:
|
|
448
|
+
return False
|
|
449
|
+
|
|
450
|
+
reason = None
|
|
451
|
+
if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
|
|
452
|
+
reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
|
|
453
|
+
elif self.device.type == 'cuda' and torch.cuda.is_available():
|
|
454
|
+
try:
|
|
455
|
+
device_index = self.device.index
|
|
456
|
+
if device_index is None:
|
|
457
|
+
device_index = torch.cuda.current_device()
|
|
458
|
+
free_mem, total_mem = torch.cuda.mem_get_info(device_index)
|
|
459
|
+
feature_bytes = X_tensor.element_size() * X_tensor.nelement()
|
|
460
|
+
required = int(feature_bytes * self.knn_gpu_mem_overhead)
|
|
461
|
+
budget = int(free_mem * self.knn_gpu_mem_ratio)
|
|
462
|
+
if required > budget:
|
|
463
|
+
required_gb = required / (1024 ** 3)
|
|
464
|
+
budget_gb = budget / (1024 ** 3)
|
|
465
|
+
reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
|
|
466
|
+
f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
|
|
467
|
+
except Exception:
|
|
468
|
+
# On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
|
|
469
|
+
reason = None
|
|
470
|
+
|
|
471
|
+
if reason:
|
|
472
|
+
self._log_knn_fallback(reason)
|
|
473
|
+
return False
|
|
474
|
+
return True
|
|
475
|
+
|
|
476
|
+
def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
|
|
477
|
+
if self.device.type == "mps":
|
|
478
|
+
return self._normalized_adj_dense(edge_index, num_nodes)
|
|
479
|
+
return self._normalized_adj_sparse(edge_index, num_nodes)
|
|
480
|
+
|
|
481
|
+
def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
|
|
482
|
+
values = torch.ones(edge_index.shape[1], device=self.device)
|
|
483
|
+
adj = torch.sparse_coo_tensor(
|
|
484
|
+
edge_index.to(self.device), values, (num_nodes, num_nodes))
|
|
485
|
+
adj = adj.coalesce()
|
|
486
|
+
|
|
487
|
+
deg = torch.sparse.sum(adj, dim=1).to_dense()
|
|
488
|
+
deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
|
|
489
|
+
row, col = adj.indices()
|
|
490
|
+
norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
|
|
491
|
+
adj_norm = torch.sparse_coo_tensor(
|
|
492
|
+
adj.indices(), norm_values, size=adj.shape)
|
|
493
|
+
return adj_norm
|
|
494
|
+
|
|
495
|
+
def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
|
|
496
|
+
if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
|
|
497
|
+
raise RuntimeError(
|
|
498
|
+
f"MPS dense adjacency not supported for {num_nodes} nodes; "
|
|
499
|
+
f"max={self.mps_dense_max_nodes}. Falling back to CPU."
|
|
500
|
+
)
|
|
501
|
+
edge_index = edge_index.to(self.device)
|
|
502
|
+
adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
|
|
503
|
+
adj[edge_index[0], edge_index[1]] = 1.0
|
|
504
|
+
deg = adj.sum(dim=1)
|
|
505
|
+
deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
|
|
506
|
+
adj = adj * deg_inv_sqrt.view(-1, 1)
|
|
507
|
+
adj = adj * deg_inv_sqrt.view(1, -1)
|
|
508
|
+
return adj
|
|
509
|
+
|
|
510
|
+
def _tensorize_split(self, X, y, w, allow_none: bool = False):
|
|
511
|
+
if X is None and allow_none:
|
|
512
|
+
return None, None, None
|
|
513
|
+
if not isinstance(X, pd.DataFrame):
|
|
514
|
+
raise ValueError("X must be a pandas DataFrame for GNN.")
|
|
515
|
+
n_rows = len(X)
|
|
516
|
+
if y is not None:
|
|
517
|
+
self._validate_vector(y, "y", n_rows)
|
|
518
|
+
if w is not None:
|
|
519
|
+
self._validate_vector(w, "w", n_rows)
|
|
520
|
+
X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
521
|
+
X, "to_numpy") else np.asarray(X, dtype=np.float32)
|
|
522
|
+
X_tensor = torch.as_tensor(
|
|
523
|
+
X_np, dtype=torch.float32, device=self.device)
|
|
524
|
+
if y is None:
|
|
525
|
+
y_tensor = None
|
|
526
|
+
else:
|
|
527
|
+
y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
528
|
+
y, "to_numpy") else np.asarray(y, dtype=np.float32)
|
|
529
|
+
y_tensor = torch.as_tensor(
|
|
530
|
+
y_np, dtype=torch.float32, device=self.device).view(-1, 1)
|
|
531
|
+
if w is None:
|
|
532
|
+
w_tensor = torch.ones(
|
|
533
|
+
(len(X), 1), dtype=torch.float32, device=self.device)
|
|
534
|
+
else:
|
|
535
|
+
w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
536
|
+
w, "to_numpy") else np.asarray(w, dtype=np.float32)
|
|
537
|
+
w_tensor = torch.as_tensor(
|
|
538
|
+
w_np, dtype=torch.float32, device=self.device).view(-1, 1)
|
|
539
|
+
return X_tensor, y_tensor, w_tensor
|
|
540
|
+
|
|
541
|
+
def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
542
|
+
if not isinstance(X_df, pd.DataFrame):
|
|
543
|
+
raise ValueError("X must be a pandas DataFrame for graph building.")
|
|
544
|
+
meta_expected = None
|
|
545
|
+
cache_key = None
|
|
546
|
+
if self.graph_cache_path:
|
|
547
|
+
meta_expected = self._graph_cache_meta(X_df)
|
|
548
|
+
if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
|
|
549
|
+
cached = self._adj_cache_tensor
|
|
550
|
+
if cached.device != self.device:
|
|
551
|
+
if self.device.type == "mps" and getattr(cached, "is_sparse", False):
|
|
552
|
+
self._adj_cache_tensor = None
|
|
553
|
+
else:
|
|
554
|
+
cached = cached.to(self.device)
|
|
555
|
+
self._adj_cache_tensor = cached
|
|
556
|
+
if self._adj_cache_tensor is not None:
|
|
557
|
+
return self._adj_cache_tensor
|
|
558
|
+
else:
|
|
559
|
+
cache_key = self._graph_cache_key(X_df)
|
|
560
|
+
if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
|
|
561
|
+
cached = self._adj_cache_tensor
|
|
562
|
+
if cached.device != self.device:
|
|
563
|
+
if self.device.type == "mps" and getattr(cached, "is_sparse", False):
|
|
564
|
+
self._adj_cache_tensor = None
|
|
565
|
+
else:
|
|
566
|
+
cached = cached.to(self.device)
|
|
567
|
+
self._adj_cache_tensor = cached
|
|
568
|
+
if self._adj_cache_tensor is not None:
|
|
569
|
+
return self._adj_cache_tensor
|
|
570
|
+
X_np = None
|
|
571
|
+
if X_tensor is None:
|
|
572
|
+
X_np = X_df.to_numpy(dtype=np.float32, copy=False)
|
|
573
|
+
X_tensor = torch.as_tensor(
|
|
574
|
+
X_np, dtype=torch.float32, device=self.device)
|
|
575
|
+
if self.graph_cache_path:
|
|
576
|
+
cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
|
|
577
|
+
if cached is not None:
|
|
578
|
+
self._adj_cache_meta = meta_expected
|
|
579
|
+
self._adj_cache_key = None
|
|
580
|
+
self._adj_cache_tensor = cached
|
|
581
|
+
return cached
|
|
582
|
+
use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
|
|
583
|
+
if use_gpu_knn:
|
|
584
|
+
edge_index = self._build_edge_index_gpu(X_tensor)
|
|
585
|
+
else:
|
|
586
|
+
if X_np is None:
|
|
587
|
+
X_np = X_df.to_numpy(dtype=np.float32, copy=False)
|
|
588
|
+
edge_index = self._build_edge_index_cpu(X_np)
|
|
589
|
+
adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
|
|
590
|
+
if self.graph_cache_path:
|
|
591
|
+
try:
|
|
592
|
+
IOUtils.ensure_parent_dir(str(self.graph_cache_path))
|
|
593
|
+
torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
|
|
594
|
+
except Exception as exc:
|
|
595
|
+
print(
|
|
596
|
+
f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
|
|
597
|
+
self._adj_cache_meta = meta_expected
|
|
598
|
+
self._adj_cache_key = None
|
|
599
|
+
else:
|
|
600
|
+
self._adj_cache_meta = None
|
|
601
|
+
self._adj_cache_key = cache_key
|
|
602
|
+
self._adj_cache_tensor = adj_norm
|
|
603
|
+
return adj_norm
|
|
604
|
+
|
|
605
|
+
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None,
        trial: Optional[optuna.trial.Trial] = None):
    """Train the GNN, routing through the MPS→CPU fallback wrapper."""
    split_args = (X_train, y_train, w_train, X_val, y_val, w_val, trial)
    return self._run_with_mps_fallback(self._fit_impl, *split_args)
|
|
618
|
+
|
|
619
|
+
def _fit_impl(self, X_train, y_train, w_train=None,
              X_val=None, y_val=None, w_val=None,
              trial: Optional[optuna.trial.Trial] = None):
    """Core full-batch training loop for the GNN.

    Tensorizes the splits, builds (cached) normalized adjacencies, then
    trains with Adam + AMP (CUDA only). Tracks best validation loss via
    ``_early_stop_update``, supports Optuna pruning (synchronized across
    DDP ranks), restores the best weights, and records ``self.best_epoch``.
    """
    X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
        X_train, y_train, w_train, allow_none=False)
    has_val = X_val is not None and y_val is not None
    if has_val:
        X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
            X_val, y_val, w_val, allow_none=False)
    else:
        X_val_tensor = y_val_tensor = w_val_tensor = None

    adj_train = self._build_graph_from_df(X_train, X_train_tensor)
    adj_val = self._build_graph_from_df(
        X_val, X_val_tensor) if has_val else None
    # DataParallel needs adjacency cached on the model to avoid scatter.
    self._set_adj_buffer(adj_train)

    base_gnn = self._unwrap_gnn()
    optimizer = torch.optim.Adam(
        base_gnn.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    # Grad scaling and autocast are both no-ops off CUDA.
    scaler = GradScaler(enabled=(self.device.type == 'cuda'))

    best_loss = float('inf')
    best_state = None
    patience_counter = 0
    best_epoch = None

    for epoch in range(1, self.epochs + 1):
        epoch_start_ts = time.time()
        self.gnn.train()
        optimizer.zero_grad()
        # Full-batch training: one forward/backward pass per epoch.
        with autocast(enabled=(self.device.type == 'cuda')):
            if self.data_parallel_enabled:
                # DataParallel path reads adjacency from the module buffer.
                y_pred = self.gnn(X_train_tensor)
            else:
                y_pred = self.gnn(X_train_tensor, adj_train)
            loss = self._compute_weighted_loss(
                y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm applies to true gradients.
        scaler.unscale_(optimizer)
        clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        val_loss = None
        if has_val:
            self.gnn.eval()
            if self.data_parallel_enabled and adj_val is not None:
                self._set_adj_buffer(adj_val)
            with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
                if self.data_parallel_enabled:
                    y_val_pred = self.gnn(X_val_tensor)
                else:
                    y_val_pred = self.gnn(X_val_tensor, adj_val)
                val_loss = self._compute_weighted_loss(
                    y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
            if self.data_parallel_enabled:
                # Restore training adjacency.
                self._set_adj_buffer(adj_train)

        # Record best *before* the update mutates best_loss.
        is_best = val_loss is not None and val_loss < best_loss
        # adj_buffer is excluded from the early-stop state snapshot.
        best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
            val_loss, best_loss, best_state, patience_counter, base_gnn,
            ignore_keys=["adj_buffer"])
        if is_best:
            best_epoch = epoch

        prune_now = False
        if trial is not None:
            # NOTE(review): val_loss may be None (no validation split) and is
            # a tensor otherwise; confirm trials are always run with a
            # validation set and that trial.report accepts this value.
            trial.report(val_loss, epoch)
            if trial.should_prune():
                prune_now = True

        if dist.is_initialized():
            # Broadcast rank 0's pruning decision so all DDP ranks agree.
            flag = torch.tensor(
                [1 if prune_now else 0],
                device=self.device,
                dtype=torch.int32,
            )
            dist.broadcast(flag, src=0)
            prune_now = bool(flag.item())

        if prune_now:
            raise optuna.TrialPruned()
        if stop_training:
            break

        # Only the main process logs under DDP.
        should_log = (not dist.is_initialized()
                      or DistributedUtils.is_main_process())
        if should_log:
            elapsed = int(time.time() - epoch_start_ts)
            if val_loss is None:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )
            else:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
                    f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )

    if best_state is not None:
        # strict=False because adj_buffer was excluded from the snapshot.
        base_gnn.load_state_dict(best_state, strict=False)
    self.best_epoch = int(best_epoch or self.epochs)
|
|
729
|
+
|
|
730
|
+
def predict(self, X: pd.DataFrame) -> np.ndarray:
    """Predict on ``X`` with automatic MPS→CPU fallback on failure."""
    impl = self._predict_impl
    return self._run_with_mps_fallback(impl, X)
|
|
732
|
+
|
|
733
|
+
def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
|
|
734
|
+
self.gnn.eval()
|
|
735
|
+
X_tensor, _, _ = self._tensorize_split(
|
|
736
|
+
X, None, None, allow_none=False)
|
|
737
|
+
adj = self._build_graph_from_df(X, X_tensor)
|
|
738
|
+
if self.data_parallel_enabled:
|
|
739
|
+
self._set_adj_buffer(adj)
|
|
740
|
+
inference_cm = getattr(torch, "inference_mode", torch.no_grad)
|
|
741
|
+
with inference_cm():
|
|
742
|
+
if self.data_parallel_enabled:
|
|
743
|
+
y_pred = self.gnn(X_tensor).cpu().numpy()
|
|
744
|
+
else:
|
|
745
|
+
y_pred = self.gnn(X_tensor, adj).cpu().numpy()
|
|
746
|
+
if self.task_type == 'classification':
|
|
747
|
+
y_pred = 1 / (1 + np.exp(-y_pred))
|
|
748
|
+
else:
|
|
749
|
+
y_pred = np.clip(y_pred, 1e-6, None)
|
|
750
|
+
return y_pred.ravel()
|
|
751
|
+
|
|
752
|
+
def encode(self, X: pd.DataFrame) -> np.ndarray:
    """Compute node embeddings for ``X`` with automatic MPS→CPU fallback."""
    impl = self._encode_impl
    return self._run_with_mps_fallback(impl, X)
|
|
754
|
+
|
|
755
|
+
def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
    """Return per-sample node embeddings (hidden representations).

    Runs the input through each GNN layer followed by an adjacency
    multiply, without the final prediction head, and returns the result
    as a NumPy array.
    """
    base = self._unwrap_gnn()
    base.eval()
    X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
    adj = self._build_graph_from_df(X, X_tensor)
    if self.data_parallel_enabled:
        # DataParallel consumes the adjacency from the module buffer.
        self._set_adj_buffer(adj)
    # inference_mode (when available) is cheaper than no_grad.
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        h = X_tensor
        layers = getattr(base, "layers", None)
        if layers is None:
            raise RuntimeError("GNN base module does not expose layers.")
        for layer in layers:
            h = layer(h, adj)
            # NOTE(review): propagation via _adj_mm after every layer even
            # though the layer already receives adj — confirm the layer does
            # not aggregate internally as well.
            h = _adj_mm(adj, h)
    return h.detach().cpu().numpy()
|
|
773
|
+
|
|
774
|
+
def set_params(self, params: Dict[str, Any]):
    """Assign known hyper-parameters, then rebuild the GNN backbone.

    Raises:
        ValueError: on the first key that is not an existing attribute
            (keys processed before it remain applied, as before).
    """
    for name, value in params.items():
        if not hasattr(self, name):
            raise ValueError(f"Parameter {name} not found in GNN model.")
        setattr(self, name, value)
    # Rebuild the backbone after structural parameter changes.
    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)
    return self
|