ltfmselector 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ltfmselector/env.py CHANGED
@@ -364,6 +364,7 @@ class Environment:
         # Get number of total recruited features
         nFSubset = (self.get_feature_mask()).sum()

+        # DEV:: If more than 10 statements, implement dictionary instead
         if self.fQueryFunction == "step":
             return self.get_fQueryCostStep(nFSubset)
         elif self.fQueryFunction == "linear":
@@ -380,17 +381,26 @@ class Environment:

     def get_fQueryCostLinear(self, _nFSubset):
         '''Linear function for querying feature'''
-        return max(
+        _qC = max(
             self.fQueryCost,
             self.fQueryCost + self.fRate*(_nFSubset-self.fThreshold)
         )
+        if not self.fCap is None:
+            return min(self.fCap, _qC)
+        else:
+            return _qC

     def get_fQueryCostQuadratic(self, _nFSubset):
         '''Quadratic function for querying feature'''
         if _nFSubset > self.fThreshold:
-            return self.fQueryCost + self.fRate*(_nFSubset-self.fThreshold)**2
+            _qC = self.fQueryCost + self.fRate*(_nFSubset-self.fThreshold)**2
         else:
-            return self.fQueryCost
+            _qC = self.fQueryCost
+
+        if not self.fCap is None:
+            return min(self.fCap, _qC)
+        else:
+            return _qC

     def get_feature_mask(self):
         '''
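The capped cost functions above are small enough to sanity-check in isolation. A minimal standalone sketch (not part of the package, with hypothetical parameter values) showing how the new fCap bound flattens the linear query cost:

# Hypothetical values, for illustration only
fQueryCost, fRate, fThreshold, fCap = 0.01, 0.005, 5, 0.04

def linear_cost(n_features):
    # Same shape as the patched get_fQueryCostLinear: base cost up to the
    # threshold, then a linear ramp, now clipped at fCap when a cap is set.
    cost = max(fQueryCost, fQueryCost + fRate * (n_features - fThreshold))
    return min(fCap, cost) if fCap is not None else cost

for n in (3, 5, 8, 20):
    print(n, linear_cost(n))  # 0.01, 0.01, 0.025, then capped at 0.04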
ltfmselector/logger.py ADDED
@@ -0,0 +1,39 @@
+import numpy as np
+
+class Logger:
+    def __init__(self, env, max_steps=500):
+        self.state_dim = env.X.shape[1]*2 + 1
+        self.max_steps = max_steps
+
+        # Pre-allocate the "scratchpad" for the current episode
+        self.s_buffer = np.zeros((self.max_steps, self.state_dim), dtype=np.float32)
+        self.a_buffer = np.zeros((self.max_steps,), dtype=np.int32)
+        self.ptr = 0
+
+        # Final storage lists
+        self.all_states = []
+        self.all_actions = []
+
+    def log_step(self, obs, action):
+        if self.ptr < self.max_steps:
+            self.s_buffer[self.ptr] = obs
+            self.a_buffer[self.ptr] = action
+            self.ptr += 1
+
+    def log_episode(self):
+        """Internal helper to copy the current buffer into the main list."""
+        if self.ptr > 0:
+            self.all_states.append(self.s_buffer[:self.ptr].copy())
+            self.all_actions.append(self.a_buffer[:self.ptr].copy())
+            self.ptr = 0
+
+    def save_data(self, filename):
+        """Saves all completed and currently-running episodes to disk."""
+        # Capture the current episode if it's mid-run
+        self.log_episode()
+        np.savez_compressed(
+            filename,
+            states=np.array(self.all_states, dtype=object),
+            actions=np.array(self.all_actions, dtype=object)
+        )
+        print(f"Saved {len(self.all_states)} episodes (states + actions) to {filename}")
ltfmselector/ltfmselector.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
 import pandas as pd
 from collections import defaultdict

+from .logger import Logger
 from .env import Environment
 from .utils import ReplayMemory, DQN, Transition

@@ -110,28 +111,29 @@ class LTFMSelector:
         Cost of querying a feature.

     fQueryFunction : None or {'step', 'linear', 'quadratic'}
-        User can also decide to progressively increase the cost of 
+        User can also decide to progressively increase the cost of
         querying features in the following manner:
-        'step' : 
-            Every additional feature adds a fixed constant, determined 
+        'step' :
+            Every additional feature adds a fixed constant, determined
             by user.
-        'linear' : 
-            Cost of every additional feature linearly increases according 
+        'linear' :
+            Cost of every additional feature linearly increases according
             to user-defined gradient
-        'quadratic' : 
-            Cost of every additional feature increases quadratically, 
+        'quadratic' :
+            Cost of every additional feature increases quadratically,
             according to a user-defined rate

     fThreshold : None or int
         If `fQueryFunction == {'step', 'linear', 'quadratic', 'exponential'}`
-        Threshold of number of features, before cost of recruiting 
+        Threshold of number of features, before cost of recruiting
         increases
-
+
     fCap : None or float
-        If `fQueryFunction == {'step'}`, upper limit of penalty
+        If `fQueryFunction == {'step', 'linear', 'quadratic'}`, upper
+        limit of penalty

     fRate : None or float
-        If `fQueryFunction == {'linear', 'quadratic', 'exponential'}`, rate
+        If `fQueryFunction == {'linear', 'quadratic'}`, rate of
         individual cost functions

     mQueryCost : float
@@ -283,7 +285,7 @@ class LTFMSelector:
     def fit(
         self, X, y, loss_function='mse', sample_weight=None,
         agent_neuralnetwork=None, lr=1e-5, returnQ=False,
-        monitor=False, background_dataset=None, **kwargs
+        monitor=False, log=False, background_dataset=None, **kwargs
     ):
         '''
         Initializes the environment and agent, then trains the agent to select
@@ -322,15 +324,18 @@

         returnQ : bool
             Return average computed action-value functions and rewards of
-            the sampled batches, as a (<total_iterations>, 3) matrix. The 
+            the sampled batches, as a (<total_iterations>, 3) matrix. The
             columns correspond to the averaged Q, reward, and target functions.

         monitor : bool
             Monitor training process using a TensorBoard.

-            Run `tensorboard --logdir=runs` in the terminal to monitor the p
+            Run `tensorboard --logdir=runs` in the terminal to monitor the
             progression of the action-value function.

+        log : bool
+            Log states and actions for research purposes
+
         background_dataset : None or pd.DataFrame
             If None, numerical features will be assumed when computing the
             background dataset.
@@ -402,6 +407,9 @@
         )
         env.reset()

+        if log:
+            logger = Logger(env, max_steps=self.max_timesteps)
+
         # Initializing length of state and actions as public fields for
         # loading the model later
         self.state_length = len(env.state)
@@ -464,6 +472,9 @@
                 # - reward
                 observation, reward, terminated = env.step(action.item())

+                if log:
+                    logger.log_step(observation, action)
+
                 if terminated:
                     next_state = None
                 else:
@@ -528,6 +539,9 @@
                     )
                     break

+            if log:
+                logger.log_episode()
+
             # Saving trained policy network intermediately
             if not self.checkpoint_interval is None:
                 if (i_episode + 1) % self.checkpoint_interval == 0:
@@ -538,6 +552,9 @@
                     self.policy_network_checkpoints[self.episodes] =\
                         self.policy_net.state_dict()

+            if log:
+                logger.save_data("ActionStates_fromFit.npz")
+
             if monitor:
                 writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
                 writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
@@ -555,7 +572,7 @@
         else:
             return doc

-    def predict(self, X_test, **kwargs):
+    def predict(self, X_test, log=False, **kwargs):
         '''
         Use trained agent to select features and a suitable prediction model
         to predict the target/class, given X_test.
@@ -565,6 +582,9 @@
         X_test : pd.DataFrame
             Test samples

+        log : bool
+            Log states and actions for research purposes
+
         Returns
         -------
         y_pred : array
@@ -583,6 +603,8 @@
             self.pType, self.regression_tol, self.regression_error_rounding,
             self.pModels, self.device, **kwargs
         )
+        if log:
+            logger = Logger(env, max_steps=self.max_timesteps)

         # Create dictionary to save information per episode
         doc_test = defaultdict(dict)
@@ -607,6 +629,9 @@

                 observation, reward, terminated = env.step(action.item())

+                if log:
+                    logger.log_step(observation, action)
+
                 if terminated:
                     next_state = None
                 else:
@@ -636,6 +661,12 @@
                     y_pred[i] = env.y_pred
                     break

+            if log:
+                logger.log_episode()
+
+        if log:
+            logger.save_data("ActionStates_fromPredict.npz")
+
         return y_pred, doc_test

     def select_action(self, state, env):
{ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ltfmselector
-Version: 0.2.0
+Version: 0.2.2
 Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
 Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
 Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -29,9 +29,11 @@ License-File: LICENSE
 Requires-Python: >=3.12
 Requires-Dist: gymnasium>=1.1.1
 Requires-Dist: matplotlib>=3.10.1
+Requires-Dist: moviepy>=2.2.1
 Requires-Dist: numpy>=2.2.4
 Requires-Dist: openpyxl>=3.1.5
 Requires-Dist: pandas>=2.2.3
+Requires-Dist: pygame>=2.6.1
 Requires-Dist: scikit-learn<1.6
 Requires-Dist: seaborn>=0.13.2
 Requires-Dist: tensorboard>=2.20.0
ltfmselector-0.2.2.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
+ltfmselector/env.py,sha256=F0NycqUkNn-p2zC1EPdds73__G8yyMW5f9F93yPDHTA,16371
+ltfmselector/logger.py,sha256=of5fgVmh1CctRE3ckjO0R_Wo6WFg8Mg1RxAI84oPKVA,1448
+ltfmselector/ltfmselector.py,sha256=jBd7dj8_l7PdTtkm1lALTzc6JX9Q954eKZrSk59EfII,33151
+ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
+ltfmselector-0.2.2.dist-info/METADATA,sha256=8AFbJExeTNTjSUvNbuw_wDo4NKfWkj0huNY599I4i20,3079
+ltfmselector-0.2.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ltfmselector-0.2.2.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
+ltfmselector-0.2.2.dist-info/RECORD,,
ltfmselector-0.2.0.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
-ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
-ltfmselector/env.py,sha256=898o_g6-i0Rz5R-4WxZInf3xaxXHf58kPJId0KeewQM,16070
-ltfmselector/ltfmselector.py,sha256=zxGTLtuaoqdWbGxM8JmQES1_kGpNad1utRfDkepPoko,32329
-ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
-ltfmselector-0.2.0.dist-info/METADATA,sha256=76QDgOBLL81otMAwr9D-eNfviT5SY76Tf70-WNGIgyg,3020
-ltfmselector-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-ltfmselector-0.2.0.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
-ltfmselector-0.2.0.dist-info/RECORD,,