ltfmselector 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ltfmselector/ltfmselector.py +47 -15
- {ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/METADATA +3 -2
- {ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/RECORD +5 -5
- {ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/WHEEL +0 -0
- {ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/licenses/LICENSE +0 -0
ltfmselector/ltfmselector.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from torch.utils.tensorboard import SummaryWriter
 
 import os
 import random
@@ -226,7 +227,7 @@ class LTFMSelector:
     def fit(
         self, X, y, loss_function='mse', sample_weight=None,
         agent_neuralnetwork=None, lr=1e-5, returnQ=False,
-        background_dataset=None, **kwargs
+        monitor=False, background_dataset=None, **kwargs
     ):
         '''
         Initializes the environment and agent, then trains the agent to select
@@ -265,7 +266,14 @@ class LTFMSelector:
 
         returnQ : bool
             Return average computed action-value functions and rewards of
-            the sampled batches,
+            the sampled batches, as a (<total_iterations>, 3) matrix. The
+            columns correspond to the averaged Q, reward, and target functions.
+
+        monitor : bool
+            Monitor the training process using TensorBoard.
+
+            Run `tensorboard --logdir=runs` in the terminal to monitor the
+            progression of the action-value function.
 
         background_dataset : None or pd.DataFrame
             If None, numerical features will be assumed when computing the
@@ -312,10 +320,18 @@ class LTFMSelector:
         self.sample_weight = sample_weight
 
         # If user wants to monitor progression of terms in the loss function
+        if monitor:
+            writer = SummaryWriter()
+            monitor_count = 1
+
+        # If user wants to save average computed action-value functions and
+        # rewards of sampled batches
         if returnQ:
-
-
-
+            total_iterations = 10000000000
+            LearningValuesMatrix = np.zeros(
+                (total_iterations, 3), dtype=np.float32
+            )
+            Q_count = 1
 
         # Initializing the environment
         env = Environment(
@@ -406,13 +422,21 @@ class LTFMSelector:
                 state = next_state
 
                 # Optimize the model
-                _res = self.optimize_model(optimizer, loss_function, returnQ)
+                _res = self.optimize_model(optimizer, loss_function, monitor, returnQ)
+
+                if monitor:
+                    if not _res is None:
+                        writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
+                        writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
+                        writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
+                        monitor_count += 1
 
                 if returnQ:
                     if not _res is None:
-
-
-
+                        LearningValuesMatrix[Q_count, 0] = _res[0]
+                        LearningValuesMatrix[Q_count, 1] = _res[1]
+                        LearningValuesMatrix[Q_count, 2] = _res[2]
+                        Q_count += 1
 
                 # Apply soft update to target network's weights
                 targetParameters = self.target_net.state_dict()
@@ -456,11 +480,19 @@ class LTFMSelector:
             self.policy_network_checkpoints[self.episodes] =\
                 self.policy_net.state_dict()
 
+        if monitor:
+            writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
+            writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
+            writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
+            writer.close()
+
         if returnQ:
-
-
-
-
+            LearningValuesMatrix[Q_count, 0] = _res[0]
+            LearningValuesMatrix[Q_count, 1] = _res[1]
+            LearningValuesMatrix[Q_count, 2] = _res[2]
+
+        if (monitor or returnQ):
+            return doc, LearningValuesMatrix[0:Q_count+1, :]
         else:
             return doc
 
@@ -572,7 +604,7 @@ class LTFMSelector:
         with torch.no_grad():
             return (self.policy_net(state).max(1)[1].view(1, 1) - 1)
 
-    def optimize_model(self, optimizer, loss_function, returnQ):
+    def optimize_model(self, optimizer, loss_function, monitor, returnQ):
         '''
         Optimize the policy network.
 
@@ -698,7 +730,7 @@ class LTFMSelector:
         # Optimize the model (policy network)
         optimizer.step()
 
-        if returnQ:
+        if (monitor or returnQ):
             Q_avr = state_action_values.detach().numpy().mean()
             r_avr = reward_batch.unsqueeze(1).numpy().mean()
             V_avr = expected_state_action_values.unsqueeze(1).numpy().mean()
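Taken together, these changes wire TensorBoard logging into the training loop. Below is a minimal usage sketch of the new `monitor` flag: the `fit` keywords, the logged metric names, and the shape of the returned matrix come from the diff above, while the import path, the constructor call, and the toy data are assumptions for illustration only.

import numpy as np
import pandas as pd

from ltfmselector import LTFMSelector  # assumed import path

# Hypothetical toy regression data
X = pd.DataFrame(np.random.rand(50, 8), columns=[f"f{i}" for i in range(8)])
y = pd.Series(np.random.rand(50))

selector = LTFMSelector()  # constructor arguments assumed, not part of this diff

# monitor=True logs Metrics/Average_QValue, Metrics/Average_Reward and
# Metrics/Average_Target per optimization step via SummaryWriter;
# returnQ=True additionally returns those values as an (iterations, 3) matrix.
doc, learning_values = selector.fit(X, y, monitor=True, returnQ=True)

# Columns of the returned matrix: averaged Q, reward, and target values
avg_Q, avg_reward, avg_target = learning_values.T

Afterwards, run `tensorboard --logdir=runs` in a terminal to view the curves, since SummaryWriter() writes to the default `runs` directory.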
{ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ltfmselector
-Version: 0.1.11
+Version: 0.1.12
 Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
 Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
 Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -32,8 +32,9 @@ Requires-Dist: matplotlib>=3.10.1
 Requires-Dist: numpy>=2.2.4
 Requires-Dist: openpyxl>=3.1.5
 Requires-Dist: pandas>=2.2.3
-Requires-Dist: scikit-learn
+Requires-Dist: scikit-learn<1.6
 Requires-Dist: seaborn>=0.13.2
+Requires-Dist: tensorboard>=2.20.0
 Requires-Dist: torch>=2.6.0
 Description-Content-Type: text/markdown
 
{ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
 ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
 ltfmselector/env.py,sha256=vizWGqDSc_2Zfs9aXjFARanIAz6PTKwUHu2_Lew9s3Y,13878
-ltfmselector/ltfmselector.py,sha256
+ltfmselector/ltfmselector.py,sha256=vs9unOmoDKq1piV6t87GC1wdy7kP8ucKHihw6i0F4KI,29567
 ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
-ltfmselector-0.1.
-ltfmselector-0.1.
-ltfmselector-0.1.
-ltfmselector-0.1.
+ltfmselector-0.1.12.dist-info/METADATA,sha256=QaUPeSx9NlZx0ZUbkEPRyFS-8nfJz9Y8yV5TXXPc7fA,3021
+ltfmselector-0.1.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ltfmselector-0.1.12.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
+ltfmselector-0.1.12.dist-info/RECORD,,
{ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/WHEEL
File without changes

{ltfmselector-0.1.11.dist-info → ltfmselector-0.1.12.dist-info}/licenses/LICENSE
File without changes