ltfmselector 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ltfmselector/logger.py ADDED
@@ -0,0 +1,39 @@
+ import numpy as np
+
+ class Logger:
+     def __init__(self, env, max_steps=500):
+         self.state_dim = env.X.shape[1]*2 + 1
+         self.max_steps = max_steps
+
+         # Pre-allocate the "scratchpad" for the current episode
+         self.s_buffer = np.zeros((self.max_steps, self.state_dim), dtype=np.float32)
+         self.a_buffer = np.zeros((self.max_steps,), dtype=np.int32)
+         self.ptr = 0
+
+         # Final storage lists
+         self.all_states = []
+         self.all_actions = []
+
+     def log_step(self, obs, action):
+         if self.ptr < self.max_steps:
+             self.s_buffer[self.ptr] = obs
+             self.a_buffer[self.ptr] = action
+             self.ptr += 1
+
+     def log_episode(self):
+         """Internal helper to copy the current buffer into the main list."""
+         if self.ptr > 0:
+             self.all_states.append(self.s_buffer[:self.ptr].copy())
+             self.all_actions.append(self.a_buffer[:self.ptr].copy())
+             self.ptr = 0
+
+     def save_data(self, filename):
+         """Saves all completed and currently-running episodes to disk."""
+         # Capture the current episode if it's mid-run
+         self.log_episode()
+         np.savez_compressed(
+             filename,
+             states=np.array(self.all_states, dtype=object),
+             actions=np.array(self.all_actions, dtype=object)
+         )
+         print(f"Saved {len(self.all_states)} episodes (states + actions) to {filename}")
ltfmselector/ltfmselector.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
  import pandas as pd
  from collections import defaultdict

+ from .logger import Logger
  from .env import Environment
  from .utils import ReplayMemory, DQN, Transition

@@ -110,25 +111,25 @@ class LTFMSelector:
  Cost of querying a feature.

  fQueryFunction : None or {'step', 'linear', 'quadratic'}
- User can also decide to progressively increase the cost of
+ User can also decide to progressively increase the cost of
  querying features in the following manner:
- 'step' :
- Every additional feature adds a fixed constant, determined
+ 'step' :
+ Every additional feature adds a fixed constant, determined
  by user.
- 'linear' :
- Cost of every additional feature linearly increases according
+ 'linear' :
+ Cost of every additional feature linearly increases according
  to user-defined gradient
- 'quadratic' :
- Cost of every additional feature increases quadratically,
+ 'quadratic' :
+ Cost of every additional feature increases quadratically,
  according to a user-defined rate

  fThreshold : None or int
  If `fQueryFunction == {'step', 'linear', 'quadratic', 'exponential'}`
- Threshold of number of features, before cost of recruiting
+ Threshold of number of features, before cost of recruiting
  increases
-
+
  fCap : None or float
- If `fQueryFunction == {'step', 'linear', 'quadratic'}`, upper
+ If `fQueryFunction == {'step', 'linear', 'quadratic'}`, upper
  limit of penalty

  fRate : None or float
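Note on the query-cost parameters documented above: `fQueryFunction` lets the query cost grow once more than `fThreshold` features have been recruited, and `fCap` bounds the penalty. The sketch below is only one plausible reading of those docstring entries; the helper name, the use of `fRate` as the gradient/rate, and the exact formulas are assumptions, not code from the package:

def nth_feature_cost(n, base_cost, fQueryFunction=None, fThreshold=0, fRate=1.0, fCap=None):
    # Illustrative only: cost charged when the n-th feature is queried.
    extra = max(0, n - fThreshold)
    if fQueryFunction == 'step':
        cost = base_cost + (fRate if extra > 0 else 0.0)  # fixed constant past the threshold
    elif fQueryFunction == 'linear':
        cost = base_cost + fRate * extra                  # grows with a user-defined gradient
    elif fQueryFunction == 'quadratic':
        cost = base_cost + fRate * extra ** 2             # grows with a user-defined rate
    else:
        cost = base_cost                                  # flat cost per query
    return min(cost, fCap) if fCap is not None else cost
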
@@ -284,7 +285,7 @@ class LTFMSelector:
  def fit(
  self, X, y, loss_function='mse', sample_weight=None,
  agent_neuralnetwork=None, lr=1e-5, returnQ=False,
- monitor=False, background_dataset=None, **kwargs
+ monitor=False, log=False, background_dataset=None, **kwargs
  ):
  '''
  Initializes the environment and agent, then trains the agent to select
@@ -323,15 +324,18 @@ class LTFMSelector:

  returnQ : bool
  Return average computed action-value functions and rewards of
- the sampled batches, as a (<total_iterations>, 3) matrix. The
+ the sampled batches, as a (<total_iterations>, 3) matrix. The
  columns correspond to the averaged Q, reward, and target functions.

  monitor : bool
  Monitor training process using a TensorBoard.

- Run `tensorboard --logdir=runs` in the terminal to monitor the p
+ Run `tensorboard --logdir=runs` in the terminal to monitor the
  progression of the action-value function.

+ log : bool
+ Log states and actions for research purposes
+
  background_dataset : None or pd.DataFrame
  If None, numerical features will be assumed when computing the
  background dataset.
@@ -403,6 +407,9 @@ class LTFMSelector:
  )
  env.reset()

+ if log:
+     logger = Logger(env, max_steps=self.max_timesteps)
+
  # Initializing length of state and actions as public fields for
  # loading the model later
  self.state_length = len(env.state)
@@ -465,6 +472,9 @@ class LTFMSelector:
  # - reward
  observation, reward, terminated = env.step(action.item())

+ if log:
+     logger.log_step(observation, action)
+
  if terminated:
  next_state = None
  else:
@@ -529,6 +539,9 @@ class LTFMSelector:
  )
  break

+ if log:
+     logger.log_episode()
+
  # Saving trained policy network intermediately
  if not self.checkpoint_interval is None:
  if (i_episode + 1) % self.checkpoint_interval == 0:
@@ -539,6 +552,9 @@ class LTFMSelector:
  self.policy_network_checkpoints[self.episodes] =\
  self.policy_net.state_dict()

+ if log:
+     logger.save_data("ActionStates_fromFit.npz")
+
  if monitor:
  writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
  writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
@@ -556,7 +572,7 @@
  else:
  return doc

- def predict(self, X_test, **kwargs):
+ def predict(self, X_test, log=False, **kwargs):
  '''
  Use trained agent to select features and a suitable prediction model
  to predict the target/class, given X_test.
@@ -566,6 +582,9 @@
  X_test : pd.DataFrame
  Test samples

+ log : bool
+ Log states and actions for research purposes
+
  Returns
  -------
  y_pred : array
@@ -584,6 +603,8 @@
  self.pType, self.regression_tol, self.regression_error_rounding,
  self.pModels, self.device, **kwargs
  )
+ if log:
+     logger = Logger(env, max_steps=self.max_timesteps)

  # Create dictionary to save information per episode
  doc_test = defaultdict(dict)
@@ -608,6 +629,9 @@

  observation, reward, terminated = env.step(action.item())

+ if log:
+     logger.log_step(observation, action)
+
  if terminated:
  next_state = None
  else:
@@ -637,6 +661,12 @@
  y_pred[i] = env.y_pred
  break

+ if log:
+     logger.log_episode()
+
+ if log:
+     logger.save_data("ActionStates_fromPredict.npz")
+
  return y_pred, doc_test

  def select_action(self, state, env):
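The same flag exists on `predict`; a sketch continuing from the training example above, assuming a trained `selector` and a held-out `X_test` DataFrame (placeholder names):

# Per-step logging during inference; writes ActionStates_fromPredict.npz
# to the current working directory.
y_pred, doc_test = selector.predict(X_test, log=True)
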
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ltfmselector
- Version: 0.2.1
+ Version: 0.2.2
  Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
  Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
  Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -29,9 +29,11 @@ License-File: LICENSE
  Requires-Python: >=3.12
  Requires-Dist: gymnasium>=1.1.1
  Requires-Dist: matplotlib>=3.10.1
+ Requires-Dist: moviepy>=2.2.1
  Requires-Dist: numpy>=2.2.4
  Requires-Dist: openpyxl>=3.1.5
  Requires-Dist: pandas>=2.2.3
+ Requires-Dist: pygame>=2.6.1
  Requires-Dist: scikit-learn<1.6
  Requires-Dist: seaborn>=0.13.2
  Requires-Dist: tensorboard>=2.20.0
ltfmselector-0.2.2.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
+ ltfmselector/env.py,sha256=F0NycqUkNn-p2zC1EPdds73__G8yyMW5f9F93yPDHTA,16371
+ ltfmselector/logger.py,sha256=of5fgVmh1CctRE3ckjO0R_Wo6WFg8Mg1RxAI84oPKVA,1448
+ ltfmselector/ltfmselector.py,sha256=jBd7dj8_l7PdTtkm1lALTzc6JX9Q954eKZrSk59EfII,33151
+ ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
+ ltfmselector-0.2.2.dist-info/METADATA,sha256=8AFbJExeTNTjSUvNbuw_wDo4NKfWkj0huNY599I4i20,3079
+ ltfmselector-0.2.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ ltfmselector-0.2.2.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
+ ltfmselector-0.2.2.dist-info/RECORD,,
ltfmselector-0.2.1.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
- ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
- ltfmselector/env.py,sha256=F0NycqUkNn-p2zC1EPdds73__G8yyMW5f9F93yPDHTA,16371
- ltfmselector/ltfmselector.py,sha256=GdxazN6JG_ELZ7a7x6bbBCVDsgmdRWOsbi-TnRFNk8Y,32354
- ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
- ltfmselector-0.2.1.dist-info/METADATA,sha256=mHEsAKWtYYsGOQyRxHX5f8fAaSXePDDklL_CLzMay9A,3020
- ltfmselector-0.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- ltfmselector-0.2.1.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
- ltfmselector-0.2.1.dist-info/RECORD,,