ltfmselector 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ltfmselector/env.py +13 -3
- ltfmselector/logger.py +39 -0
- ltfmselector/ltfmselector.py +46 -15
- {ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/METADATA +3 -1
- ltfmselector-0.2.2.dist-info/RECORD +10 -0
- ltfmselector-0.2.0.dist-info/RECORD +0 -9
- {ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/WHEEL +0 -0
- {ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/licenses/LICENSE +0 -0
ltfmselector/env.py
CHANGED
@@ -364,6 +364,7 @@ class Environment:
         # Get number of total recruited features
         nFSubset = (self.get_feature_mask()).sum()
 
+        # DEV:: If more than 10 statements, implement dictionary instead
         if self.fQueryFunction == "step":
            return self.get_fQueryCostStep(nFSubset)
        elif self.fQueryFunction == "linear":
@@ -380,17 +381,26 @@
 
     def get_fQueryCostLinear(self, _nFSubset):
         '''Linear function for querying feature'''
-        return max(
+        _qC = max(
             self.fQueryCost,
             self.fQueryCost + self.fRate*(_nFSubset-self.fThreshold)
         )
+        if not self.fCap is None:
+            return min(self.fCap, _qC)
+        else:
+            return _qC
 
     def get_fQueryCostQuadratic(self, _nFSubset):
         '''Quadratic function for querying feature'''
         if _nFSubset > self.fThreshold:
-            return self.fQueryCost + self.fRate*(_nFSubset-self.fThreshold)**2
+            _qC = self.fQueryCost + self.fRate*(_nFSubset-self.fThreshold)**2
         else:
-            return self.fQueryCost
+            _qC = self.fQueryCost
+
+        if not self.fCap is None:
+            return min(self.fCap, _qC)
+        else:
+            return _qC
 
     def get_feature_mask(self):
         '''
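For orientation, here is a minimal standalone sketch (not part of the package) of what the new capped cost functions compute. The names query_cost, rate, threshold, and cap are illustrative stand-ins for the fQueryCost, fRate, fThreshold, and fCap attributes seen in the diff, and the numbers are made up.

def linear_cost(n, query_cost=1.0, rate=0.5, threshold=10, cap=None):
    """Linear per-feature query cost; capped at `cap` when it is set."""
    qc = max(query_cost, query_cost + rate*(n - threshold))
    return qc if cap is None else min(cap, qc)

def quadratic_cost(n, query_cost=1.0, rate=0.5, threshold=10, cap=None):
    """Quadratic per-feature query cost; capped at `cap` when it is set."""
    qc = query_cost + rate*(n - threshold)**2 if n > threshold else query_cost
    return qc if cap is None else min(cap, qc)

# Below the threshold the cost stays flat; above it it grows until the cap bites
for n in (5, 10, 12, 20):
    print(n, linear_cost(n, cap=3.0), quadratic_cost(n, cap=3.0))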
ltfmselector/logger.py
ADDED
@@ -0,0 +1,39 @@
+import numpy as np
+
+class Logger:
+    def __init__(self, env, max_steps=500):
+        self.state_dim = env.X.shape[1]*2 + 1
+        self.max_steps = max_steps
+
+        # Pre-allocate the "scratchpad" for the current episode
+        self.s_buffer = np.zeros((self.max_steps, self.state_dim), dtype=np.float32)
+        self.a_buffer = np.zeros((self.max_steps,), dtype=np.int32)
+        self.ptr = 0
+
+        # Final storage lists
+        self.all_states = []
+        self.all_actions = []
+
+    def log_step(self, obs, action):
+        if self.ptr < self.max_steps:
+            self.s_buffer[self.ptr] = obs
+            self.a_buffer[self.ptr] = action
+            self.ptr += 1
+
+    def log_episode(self):
+        """Internal helper to copy the current buffer into the main list."""
+        if self.ptr > 0:
+            self.all_states.append(self.s_buffer[:self.ptr].copy())
+            self.all_actions.append(self.a_buffer[:self.ptr].copy())
+        self.ptr = 0
+
+    def save_data(self, filename):
+        """Saves all completed and currently-running episodes to disk."""
+        # Capture the current episode if it's mid-run
+        self.log_episode()
+        np.savez_compressed(
+            filename,
+            states=np.array(self.all_states, dtype=object),
+            actions=np.array(self.all_actions, dtype=object)
+        )
+        print(f"Saved {len(self.all_states)} episodes (states + actions) to {filename}")
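The new Logger pre-allocates fixed-size per-episode buffers and only copies the filled prefix into Python lists when an episode closes, avoiding per-step list appends. A hedged usage sketch, assuming the 0.2.2 wheel is installed and using a stand-in object for the environment (the class only reads env.X.shape[1]):

import numpy as np
from types import SimpleNamespace

from ltfmselector.logger import Logger

# Stand-in environment: 4 features -> state_dim = 4*2 + 1 = 9
env = SimpleNamespace(X=np.zeros((100, 4)))

logger = Logger(env, max_steps=50)
for t in range(3):  # one short, illustrative episode
    obs = np.random.rand(9).astype(np.float32)
    logger.log_step(obs, t)
logger.log_episode()                   # flush the episode buffer into the lists
logger.save_data("demo_episodes.npz")  # also captures any episode still mid-run

# Object arrays require allow_pickle=True on load
data = np.load("demo_episodes.npz", allow_pickle=True)
print(data["states"][0].shape)         # -> (3, 9)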
ltfmselector/ltfmselector.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
 import pandas as pd
 from collections import defaultdict
 
+from .logger import Logger
 from .env import Environment
 from .utils import ReplayMemory, DQN, Transition
 
@@ -110,28 +111,29 @@ class LTFMSelector:
         Cost of querying a feature.
 
     fQueryFunction : None or {'step', 'linear', 'quadratic'}
-        User can also decide to progressively increase the cost of 
+        User can also decide to progressively increase the cost of
         querying features in the following manner:
-        'step' : 
-            Every additional feature adds a fixed constant, determined 
+        'step' :
+            Every additional feature adds a fixed constant, determined
             by user.
-        'linear' : 
-            Cost of every additional feature linearly increases according 
+        'linear' :
+            Cost of every additional feature linearly increases according
             to user-defined gradient
-        'quadratic' : 
-            Cost of every additional feature increases quadratically, 
+        'quadratic' :
+            Cost of every additional feature increases quadratically,
             according to a user-defined rate
 
     fThreshold : None or int
         If `fQueryFunction == {'step', 'linear', 'quadratic', 'exponential'}`
-        Threshold of number of features, before cost of recruiting 
+        Threshold of number of features, before cost of recruiting
         increases
-    
+
     fCap : None or float
-        If `fQueryFunction == {'step'}`, upper
+        If `fQueryFunction == {'step', 'linear', 'quadratic'}`, upper
+        limit of penalty
 
     fRate : None or float
-        If `fQueryFunction == {'linear', 'quadratic'
+        If `fQueryFunction == {'linear', 'quadratic'}`, rate of
         individual cost functions
 
     mQueryCost : float
@@ -283,7 +285,7 @@
     def fit(
         self, X, y, loss_function='mse', sample_weight=None,
         agent_neuralnetwork=None, lr=1e-5, returnQ=False,
-        monitor=False, background_dataset=None, **kwargs
+        monitor=False, log=False, background_dataset=None, **kwargs
     ):
         '''
         Initializes the environment and agent, then trains the agent to select
@@ -322,15 +324,18 @@
 
         returnQ : bool
             Return average computed action-value functions and rewards of
-            the sampled batches, as a (<total_iterations>, 3) matrix. The 
+            the sampled batches, as a (<total_iterations>, 3) matrix. The
             columns correspond to the averaged Q, reward, and target functions.
 
         monitor : bool
             Monitor training process using a TensorBoard.
 
-            Run `tensorboard --logdir=runs` in the terminal to monitor the 
+            Run `tensorboard --logdir=runs` in the terminal to monitor the
             progression of the action-value function.
 
+        log : bool
+            Log states and actions for research purposes
+
         background_dataset : None or pd.DataFrame
             If None, numerical features will be assumed when computing the
             background dataset.
@@ -402,6 +407,9 @@
         )
         env.reset()
 
+        if log:
+            logger = Logger(env, max_steps=self.max_timesteps)
+
         # Initializing length of state and actions as public fields for
         # loading the model later
         self.state_length = len(env.state)
@@ -464,6 +472,9 @@
             # - reward
             observation, reward, terminated = env.step(action.item())
 
+            if log:
+                logger.log_step(observation, action)
+
             if terminated:
                 next_state = None
             else:
@@ -528,6 +539,9 @@
                 )
                 break
 
+            if log:
+                logger.log_episode()
+
             # Saving trained policy network intermediately
             if not self.checkpoint_interval is None:
                 if (i_episode + 1) % self.checkpoint_interval == 0:
@@ -538,6 +552,9 @@
                     self.policy_network_checkpoints[self.episodes] =\
                         self.policy_net.state_dict()
 
+            if log:
+                logger.save_data("ActionStates_fromFit.npz")
+
             if monitor:
                 writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
                 writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
@@ -555,7 +572,7 @@
         else:
             return doc
 
-    def predict(self, X_test, **kwargs):
+    def predict(self, X_test, log=False, **kwargs):
         '''
         Use trained agent to select features and a suitable prediction model
         to predict the target/class, given X_test.
@@ -565,6 +582,9 @@
         X_test : pd.DataFrame
             Test samples
 
+        log : bool
+            Log states and actions for research purposes
+
         Returns
         -------
         y_pred : array
@@ -583,6 +603,8 @@
             self.pType, self.regression_tol, self.regression_error_rounding,
             self.pModels, self.device, **kwargs
         )
+        if log:
+            logger = Logger(env, max_steps=self.max_timesteps)
 
         # Create dictionary to save information per episode
         doc_test = defaultdict(dict)
@@ -607,6 +629,9 @@
 
             observation, reward, terminated = env.step(action.item())
 
+            if log:
+                logger.log_step(observation, action)
+
             if terminated:
                 next_state = None
             else:
@@ -636,6 +661,12 @@
                     y_pred[i] = env.y_pred
                     break
 
+            if log:
+                logger.log_episode()
+
+        if log:
+            logger.save_data("ActionStates_fromPredict.npz")
+
         return y_pred, doc_test
 
     def select_action(self, state, env):
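Taken together, the new `log` flag instantiates a Logger next to the environment, records every (observation, action) pair, flushes per episode, and writes ActionStates_fromFit.npz / ActionStates_fromPredict.npz at the end. A sketch of the call pattern on a toy regression problem; the data is synthetic and the default construction of the selector is an assumption (see the package docs for the real constructor arguments):

import numpy as np
import pandas as pd

from ltfmselector import LTFMSelector  # assuming the package exports the class

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 5)), columns=[f"f{i}" for i in range(5)])
y = pd.Series(2.0*X["f0"] + rng.normal(scale=0.1, size=60))

selector = LTFMSelector()  # hypothetical default construction

selector.fit(X, y, log=True)                 # writes ActionStates_fromFit.npz
y_pred, doc = selector.predict(X, log=True)  # writes ActionStates_fromPredict.npz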
{ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ltfmselector
-Version: 0.2.0
+Version: 0.2.2
 Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
 Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
 Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -29,9 +29,11 @@ License-File: LICENSE
 Requires-Python: >=3.12
 Requires-Dist: gymnasium>=1.1.1
 Requires-Dist: matplotlib>=3.10.1
+Requires-Dist: moviepy>=2.2.1
 Requires-Dist: numpy>=2.2.4
 Requires-Dist: openpyxl>=3.1.5
 Requires-Dist: pandas>=2.2.3
+Requires-Dist: pygame>=2.6.1
 Requires-Dist: scikit-learn<1.6
 Requires-Dist: seaborn>=0.13.2
 Requires-Dist: tensorboard>=2.20.0
ltfmselector-0.2.2.dist-info/RECORD
ADDED
@@ -0,0 +1,10 @@
+ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
+ltfmselector/env.py,sha256=F0NycqUkNn-p2zC1EPdds73__G8yyMW5f9F93yPDHTA,16371
+ltfmselector/logger.py,sha256=of5fgVmh1CctRE3ckjO0R_Wo6WFg8Mg1RxAI84oPKVA,1448
+ltfmselector/ltfmselector.py,sha256=jBd7dj8_l7PdTtkm1lALTzc6JX9Q954eKZrSk59EfII,33151
+ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
+ltfmselector-0.2.2.dist-info/METADATA,sha256=8AFbJExeTNTjSUvNbuw_wDo4NKfWkj0huNY599I4i20,3079
+ltfmselector-0.2.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ltfmselector-0.2.2.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
+ltfmselector-0.2.2.dist-info/RECORD,,
ltfmselector-0.2.0.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
-ltfmselector/env.py,sha256=898o_g6-i0Rz5R-4WxZInf3xaxXHf58kPJId0KeewQM,16070
-ltfmselector/ltfmselector.py,sha256=zxGTLtuaoqdWbGxM8JmQES1_kGpNad1utRfDkepPoko,32329
-ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
-ltfmselector-0.2.0.dist-info/METADATA,sha256=76QDgOBLL81otMAwr9D-eNfviT5SY76Tf70-WNGIgyg,3020
-ltfmselector-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-ltfmselector-0.2.0.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
-ltfmselector-0.2.0.dist-info/RECORD,,
{ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/WHEEL
File without changes
{ltfmselector-0.2.0.dist-info → ltfmselector-0.2.2.dist-info}/licenses/LICENSE
File without changes