ltfmselector 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- ltfmselector/ltfmselector.py
+++ ltfmselector/ltfmselector.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from torch.utils.tensorboard import SummaryWriter
 
 import os
 import random
@@ -226,7 +227,7 @@ class LTFMSelector:
     def fit(
         self, X, y, loss_function='mse', sample_weight=None,
         agent_neuralnetwork=None, lr=1e-5, returnQ=False,
-        background_dataset=None, **kwargs
+        monitor=False, background_dataset=None, **kwargs
     ):
         '''
         Initializes the environment and agent, then trains the agent to select
@@ -265,7 +266,14 @@ class LTFMSelector:
 
         returnQ : bool
             Return average computed action-value functions and rewards of
-            the sampled batches, for debugging purposes.
+            the sampled batches, as a (<total_iterations>, 3) matrix. The
+            columns correspond to the averaged Q, reward, and target functions.
+
+        monitor : bool
+            Monitor the training process with TensorBoard.
+
+            Run `tensorboard --logdir=runs` in the terminal to follow the
+            progression of the action-value function.
 
         background_dataset : None or pd.DataFrame
             If None, numerical features will be assumed when computing the
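As a quick orientation for the new keywords, here is a minimal usage sketch (not part of the diff). Only fit()'s keyword names and the TensorBoard log directory come from the changes above; the constructor call and the toy data are assumptions.

import numpy as np
import pandas as pd
from ltfmselector import LTFMSelector

# Toy regression data, purely illustrative
X = pd.DataFrame(np.random.rand(50, 8), columns=[f"f{i}" for i in range(8)])
y = pd.Series(np.random.rand(50))

selector = LTFMSelector()  # assumed default construction; actual arguments may differ
doc, learning_values = selector.fit(
    X, y,
    loss_function='mse',
    monitor=True,   # log averaged Q-value, reward and target to ./runs for TensorBoard
    returnQ=True,   # also return them as an (iterations, 3) matrix
)
# Then, in a terminal: tensorboard --logdir=runs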
@@ -312,10 +320,18 @@ class LTFMSelector:
         self.sample_weight = sample_weight
 
         # If user wants to monitor progression of terms in the loss function
+        if monitor:
+            writer = SummaryWriter()
+            monitor_count = 1
+
+        # If user wants to save average computed action-value functions and
+        # rewards of sampled batches
         if returnQ:
-            Q_avr_list = []
-            r_avr_list = []
-            V_avr_list = []
+            total_iterations = 10000000000
+            LearningValuesMatrix = np.zeros(
+                (total_iterations, 3), dtype=np.float32
+            )
+            Q_count = 1
 
         # Initializing the environment
         env = Environment(
@@ -406,13 +422,21 @@ class LTFMSelector:
             state = next_state
 
             # Optimize the model
-            _res = self.optimize_model(optimizer, loss_function, returnQ)
+            _res = self.optimize_model(optimizer, loss_function, monitor, returnQ)
+
+            if monitor:
+                if not _res is None:
+                    writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
+                    writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
+                    writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
+                    monitor_count += 1
 
             if returnQ:
                 if not _res is None:
-                    Q_avr_list.append(_res[0])
-                    r_avr_list.append(_res[1])
-                    V_avr_list.append(_res[2])
+                    LearningValuesMatrix[Q_count, 0] = _res[0]
+                    LearningValuesMatrix[Q_count, 1] = _res[1]
+                    LearningValuesMatrix[Q_count, 2] = _res[2]
+                    Q_count += 1
 
             # Apply soft update to target network's weights
             targetParameters = self.target_net.state_dict()
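The logging added above follows the standard torch.utils.tensorboard pattern. Here is a self-contained sketch of that pattern, with placeholder metric values rather than the selector's real ones:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # event files go under ./runs by default
for step in range(1, 101):
    # Placeholder metrics; in the package these are the batch-averaged Q-value,
    # reward and target returned by optimize_model()
    q_avr, r_avr, v_avr = 1.0 / step, 0.5, 1.0 / step + 0.5
    writer.add_scalar("Metrics/Average_QValue", q_avr, step)
    writer.add_scalar("Metrics/Average_Reward", r_avr, step)
    writer.add_scalar("Metrics/Average_Target", v_avr, step)
writer.close()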
@@ -456,11 +480,19 @@
         self.policy_network_checkpoints[self.episodes] =\
             self.policy_net.state_dict()
 
+        if monitor:
+            writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
+            writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
+            writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
+            writer.close()
+
         if returnQ:
-            Q_avr_list.append(_res[0])
-            r_avr_list.append(_res[1])
-            V_avr_list.append(_res[2])
-            return doc, (Q_avr_list, r_avr_list, V_avr_list)
+            LearningValuesMatrix[Q_count, 0] = _res[0]
+            LearningValuesMatrix[Q_count, 1] = _res[1]
+            LearningValuesMatrix[Q_count, 2] = _res[2]
+
+        if (monitor or returnQ):
+            return doc, LearningValuesMatrix[0:Q_count+1, :]
         else:
             return doc
 
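When fit() is called with returnQ=True (or monitor=True), the second return value is the (iterations, 3) matrix described in the docstring, with columns ordered Q-value, reward, target. A plotting sketch, assuming the learning_values array from the usage example above and matplotlib (already a declared dependency):

import matplotlib.pyplot as plt

# learning_values as returned by selector.fit(..., returnQ=True) in the earlier sketch
for idx, label in enumerate(["Average Q-value", "Average reward", "Average target"]):
    plt.plot(learning_values[:, idx], label=label)
plt.xlabel("Optimization step")
plt.ylabel("Batch average")
plt.legend()
plt.show()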
@@ -572,7 +604,7 @@ class LTFMSelector:
            with torch.no_grad():
                return (self.policy_net(state).max(1)[1].view(1, 1) - 1)
 
-    def optimize_model(self, optimizer, loss_function, returnQ):
+    def optimize_model(self, optimizer, loss_function, monitor, returnQ):
         '''
         Optimize the policy network.
 
@@ -698,7 +730,7 @@
         # Optimize the model (policy network)
         optimizer.step()
 
-        if returnQ:
+        if (monitor or returnQ):
             Q_avr = state_action_values.detach().numpy().mean()
             r_avr = reward_batch.unsqueeze(1).numpy().mean()
             V_avr = expected_state_action_values.unsqueeze(1).numpy().mean()
--- ltfmselector-0.1.11.dist-info/METADATA
+++ ltfmselector-0.1.12.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ltfmselector
-Version: 0.1.11
+Version: 0.1.12
 Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
 Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
 Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -32,8 +32,9 @@ Requires-Dist: matplotlib>=3.10.1
 Requires-Dist: numpy>=2.2.4
 Requires-Dist: openpyxl>=3.1.5
 Requires-Dist: pandas>=2.2.3
-Requires-Dist: scikit-learn>=1.6.1
+Requires-Dist: scikit-learn<1.6
 Requires-Dist: seaborn>=0.13.2
+Requires-Dist: tensorboard>=2.20.0
 Requires-Dist: torch>=2.6.0
 Description-Content-Type: text/markdown
 
--- ltfmselector-0.1.11.dist-info/RECORD
+++ ltfmselector-0.1.12.dist-info/RECORD
@@ -1,9 +1,9 @@
 ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
 ltfmselector/env.py,sha256=vizWGqDSc_2Zfs9aXjFARanIAz6PTKwUHu2_Lew9s3Y,13878
-ltfmselector/ltfmselector.py,sha256=-uRYcj89l8GL5re5Zw0mxe9tp8ulSmSBawqLA96S5A8,27984
+ltfmselector/ltfmselector.py,sha256=vs9unOmoDKq1piV6t87GC1wdy7kP8ucKHihw6i0F4KI,29567
 ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
-ltfmselector-0.1.11.dist-info/METADATA,sha256=rPl6VeICmX_bw4B5QEZLlsXJcMVJ2-xGfJ9cZwTe3oA,2989
-ltfmselector-0.1.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-ltfmselector-0.1.11.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
-ltfmselector-0.1.11.dist-info/RECORD,,
+ltfmselector-0.1.12.dist-info/METADATA,sha256=QaUPeSx9NlZx0ZUbkEPRyFS-8nfJz9Y8yV5TXXPc7fA,3021
+ltfmselector-0.1.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ltfmselector-0.1.12.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
+ltfmselector-0.1.12.dist-info/RECORD,,