tensortrade 1.0.0b0__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (111)
  1. tensortrade/__init__.py +23 -16
  2. tensortrade/agents/__init__.py +7 -7
  3. tensortrade/agents/a2c_agent.py +239 -237
  4. tensortrade/agents/agent.py +52 -49
  5. tensortrade/agents/dqn_agent.py +375 -202
  6. tensortrade/agents/parallel/__init__.py +5 -5
  7. tensortrade/agents/parallel/parallel_dqn_agent.py +172 -170
  8. tensortrade/agents/parallel/parallel_dqn_model.py +85 -83
  9. tensortrade/agents/parallel/parallel_dqn_optimizer.py +96 -90
  10. tensortrade/agents/parallel/parallel_dqn_trainer.py +97 -95
  11. tensortrade/agents/parallel/parallel_queue.py +95 -92
  12. tensortrade/agents/replay_memory.py +54 -52
  13. tensortrade/core/__init__.py +6 -6
  14. tensortrade/core/base.py +167 -173
  15. tensortrade/core/clock.py +48 -48
  16. tensortrade/core/component.py +129 -129
  17. tensortrade/core/context.py +182 -182
  18. tensortrade/core/exceptions.py +211 -211
  19. tensortrade/core/registry.py +45 -45
  20. tensortrade/data/__init__.py +1 -1
  21. tensortrade/data/cdd.py +152 -151
  22. tensortrade/env/__init__.py +2 -2
  23. tensortrade/env/default/__init__.py +96 -89
  24. tensortrade/env/default/actions.py +428 -399
  25. tensortrade/env/default/informers.py +14 -16
  26. tensortrade/env/default/observers.py +475 -284
  27. tensortrade/env/default/renderers.py +787 -586
  28. tensortrade/env/default/rewards.py +360 -240
  29. tensortrade/env/default/stoppers.py +33 -33
  30. tensortrade/env/generic/__init__.py +22 -22
  31. tensortrade/env/generic/components/__init__.py +13 -13
  32. tensortrade/env/generic/components/action_scheme.py +54 -54
  33. tensortrade/env/generic/components/informer.py +45 -45
  34. tensortrade/env/generic/components/observer.py +59 -59
  35. tensortrade/env/generic/components/renderer.py +86 -86
  36. tensortrade/env/generic/components/reward_scheme.py +44 -44
  37. tensortrade/env/generic/components/stopper.py +46 -46
  38. tensortrade/env/generic/environment.py +211 -163
  39. tensortrade/feed/__init__.py +5 -5
  40. tensortrade/feed/api/__init__.py +5 -5
  41. tensortrade/feed/api/boolean/__init__.py +44 -44
  42. tensortrade/feed/api/boolean/operations.py +20 -20
  43. tensortrade/feed/api/float/__init__.py +49 -48
  44. tensortrade/feed/api/float/accumulators.py +199 -199
  45. tensortrade/feed/api/float/imputation.py +40 -40
  46. tensortrade/feed/api/float/operations.py +233 -233
  47. tensortrade/feed/api/float/ordering.py +105 -105
  48. tensortrade/feed/api/float/utils.py +140 -140
  49. tensortrade/feed/api/float/window/__init__.py +3 -3
  50. tensortrade/feed/api/float/window/ewm.py +459 -452
  51. tensortrade/feed/api/float/window/expanding.py +189 -185
  52. tensortrade/feed/api/float/window/rolling.py +227 -217
  53. tensortrade/feed/api/generic/__init__.py +4 -4
  54. tensortrade/feed/api/generic/imputation.py +51 -51
  55. tensortrade/feed/api/generic/operators.py +118 -121
  56. tensortrade/feed/api/generic/reduce.py +119 -119
  57. tensortrade/feed/api/generic/warmup.py +54 -54
  58. tensortrade/feed/api/string/__init__.py +44 -43
  59. tensortrade/feed/api/string/operations.py +135 -131
  60. tensortrade/feed/core/__init__.py +3 -3
  61. tensortrade/feed/core/accessors.py +30 -30
  62. tensortrade/feed/core/base.py +634 -584
  63. tensortrade/feed/core/feed.py +120 -59
  64. tensortrade/feed/core/methods.py +37 -37
  65. tensortrade/feed/core/mixins.py +23 -23
  66. tensortrade/feed/core/operators.py +174 -174
  67. tensortrade/oms/__init__.py +2 -2
  68. tensortrade/oms/exchanges/__init__.py +1 -1
  69. tensortrade/oms/exchanges/exchange.py +176 -164
  70. tensortrade/oms/instruments/__init__.py +5 -5
  71. tensortrade/oms/instruments/exchange_pair.py +44 -44
  72. tensortrade/oms/instruments/instrument.py +161 -161
  73. tensortrade/oms/instruments/quantity.py +321 -318
  74. tensortrade/oms/instruments/trading_pair.py +58 -58
  75. tensortrade/oms/orders/__init__.py +13 -13
  76. tensortrade/oms/orders/broker.py +129 -125
  77. tensortrade/oms/orders/create.py +312 -312
  78. tensortrade/oms/orders/criteria.py +218 -218
  79. tensortrade/oms/orders/order.py +368 -368
  80. tensortrade/oms/orders/order_listener.py +62 -62
  81. tensortrade/oms/orders/order_spec.py +102 -102
  82. tensortrade/oms/orders/trade.py +159 -159
  83. tensortrade/oms/services/__init__.py +2 -2
  84. tensortrade/oms/services/execution/__init__.py +4 -4
  85. tensortrade/oms/services/execution/simulated.py +197 -183
  86. tensortrade/oms/services/slippage/__init__.py +21 -21
  87. tensortrade/oms/services/slippage/random_slippage_model.py +56 -56
  88. tensortrade/oms/services/slippage/slippage_model.py +46 -46
  89. tensortrade/oms/wallets/__init__.py +20 -20
  90. tensortrade/oms/wallets/ledger.py +92 -92
  91. tensortrade/oms/wallets/portfolio.py +330 -329
  92. tensortrade/oms/wallets/wallet.py +376 -365
  93. tensortrade/stochastic/__init__.py +12 -12
  94. tensortrade/stochastic/processes/brownian_motion.py +55 -55
  95. tensortrade/stochastic/processes/cox.py +103 -103
  96. tensortrade/stochastic/processes/fbm.py +88 -88
  97. tensortrade/stochastic/processes/gbm.py +129 -129
  98. tensortrade/stochastic/processes/heston.py +281 -281
  99. tensortrade/stochastic/processes/merton.py +91 -91
  100. tensortrade/stochastic/processes/ornstein_uhlenbeck.py +113 -113
  101. tensortrade/stochastic/utils/__init__.py +2 -2
  102. tensortrade/stochastic/utils/helpers.py +180 -179
  103. tensortrade/stochastic/utils/parameters.py +172 -172
  104. tensortrade/version.py +1 -1
  105. tensortrade-1.0.4.dist-info/METADATA +65 -0
  106. tensortrade-1.0.4.dist-info/RECORD +114 -0
  107. {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info}/WHEEL +1 -1
  108. {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info/licenses}/LICENSE +200 -200
  109. tensortrade-1.0.0b0.dist-info/METADATA +0 -74
  110. tensortrade-1.0.0b0.dist-info/RECORD +0 -114
  111. {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info}/top_level.txt +0 -0
tensortrade/__init__.py CHANGED
@@ -1,16 +1,23 @@
-
-from . import core
-from . import data
-from . import feed
-from tensortrade.oms import (
-    orders,
-    wallets,
-    instruments,
-    exchanges,
-    services
-)
-from . import env
-from . import stochastic
-from . import agents
-
-from .version import __version__
+import sys
+
+if sys.version_info < (3, 12):
+    raise RuntimeError(
+        f"TensorTrade requires Python 3.12 or higher. "
+        f"You are using Python {sys.version_info.major}.{sys.version_info.minor}."
+    )
+
+from . import core
+from . import data
+from . import feed
+from tensortrade.oms import (
+    orders,
+    wallets,
+    instruments,
+    exchanges,
+    services
+)
+from . import env
+from . import stochastic
+from . import agents
+
+from .version import __version__
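Note: with this change the package refuses to import on interpreters older than Python 3.12, failing fast with a RuntimeError instead of breaking later. A minimal sketch of how a downstream script might surface that requirement before importing; the MIN_PY constant and the early-exit pattern are illustrative, not part of tensortrade:

import sys

MIN_PY = (3, 12)  # illustrative constant mirroring the check tensortrade 1.0.4 performs at import time

if sys.version_info < MIN_PY:
    # Exit with a readable message instead of letting the import raise deeper in the stack.
    sys.exit(f"This script needs Python {MIN_PY[0]}.{MIN_PY[1]}+ for tensortrade; "
             f"found {sys.version_info.major}.{sys.version_info.minor}.")

import tensortrade  # safe once the interpreter version is known to be new enough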
tensortrade/agents/__init__.py CHANGED
@@ -1,7 +1,7 @@
-from .agent import Agent
-from .replay_memory import ReplayMemory
-
-from .dqn_agent import DQNAgent, DQNTransition
-from .a2c_agent import A2CAgent, A2CTransition
-
-from .parallel import ParallelDQNAgent
+from .agent import Agent
+from .replay_memory import ReplayMemory
+
+from .dqn_agent import DQNAgent, DQNTransition
+from .a2c_agent import A2CAgent, A2CTransition
+
+from .parallel import ParallelDQNAgent
tensortrade/agents/a2c_agent.py CHANGED
@@ -1,237 +1,239 @@
-# Copyright 2019 The TensorTrade Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-
-
-"""
-References:
-    - http://inoryy.com/post/tensorflow2-deep-reinforcement-learning/#agent-interface
-"""
-
-
-import random
-import numpy as np
-import tensorflow as tf
-
-from collections import namedtuple
-
-from tensortrade.agents import Agent, ReplayMemory
-
-A2CTransition = namedtuple('A2CTransition', ['state', 'action', 'reward', 'done', 'value'])
-
-
-class A2CAgent(Agent):
-
-    def __init__(self,
-                 env: 'TradingEnvironment',
-                 shared_network: tf.keras.Model = None,
-                 actor_network: tf.keras.Model = None,
-                 critic_network: tf.keras.Model = None):
-        self.env = env
-        self.n_actions = env.action_space.n
-        self.observation_shape = env.observation_space.shape
-
-        self.shared_network = shared_network or self._build_shared_network()
-        self.actor_network = actor_network or self._build_actor_network()
-        self.critic_network = critic_network or self._build_critic_network()
-
-        self.env.agent_id = self.id
-
-    def _build_shared_network(self):
-        network = tf.keras.Sequential([
-            tf.keras.layers.InputLayer(input_shape=self.observation_shape),
-            tf.keras.layers.Conv1D(filters=64, kernel_size=6, padding="same", activation="tanh"),
-            tf.keras.layers.MaxPooling1D(pool_size=2),
-            tf.keras.layers.Conv1D(filters=32, kernel_size=3, padding="same", activation="tanh"),
-            tf.keras.layers.MaxPooling1D(pool_size=2),
-            tf.keras.layers.Flatten()
-        ])
-
-        return network
-
-    def _build_actor_network(self):
-        actor_head = tf.keras.Sequential([
-            tf.keras.layers.Dense(50, activation='relu'),
-            tf.keras.layers.Dense(self.n_actions, activation='relu')
-        ])
-
-        return tf.keras.Sequential([self.shared_network, actor_head])
-
-    def _build_critic_network(self):
-        critic_head = tf.keras.Sequential([
-            tf.keras.layers.Dense(50, activation='relu'),
-            tf.keras.layers.Dense(25, activation='relu'),
-            tf.keras.layers.Dense(1, activation='relu')
-        ])
-
-        return tf.keras.Sequential([self.shared_network, critic_head])
-
-    def restore(self, path: str, **kwargs):
-        actor_filename: str = kwargs.get('actor_filename', None)
-        critic_filename: str = kwargs.get('critic_filename', None)
-
-        if not actor_filename or not critic_filename:
-            raise ValueError(
-                'The `restore` method requires a directory `path`, a `critic_filename`, and an `actor_filename`.')
-
-        self.actor_network = tf.keras.models.load_model(path + actor_filename)
-        self.critic_network = tf.keras.models.load_model(path + critic_filename)
-
-    def save(self, path: str, **kwargs):
-        episode: int = kwargs.get('episode', None)
-
-        if episode:
-            suffix = self.id + "__" + str(episode).zfill(3) + ".hdf5"
-            actor_filename = "actor_network__" + suffix
-            critic_filename = "critic_network__" + suffix
-        else:
-            actor_filename = "actor_network__" + self.id + ".hdf5"
-            critic_filename = "critic_network__" + self.id + ".hdf5"
-
-        self.actor_network.save(path + actor_filename)
-        self.critic_network.save(path + critic_filename)
-
-    def get_action(self, state: np.ndarray, **kwargs) -> int:
-        threshold: float = kwargs.get('threshold', 0)
-
-        rand = random.random()
-
-        if rand < threshold:
-            return np.random.choice(self.n_actions)
-        else:
-            logits = self.actor_network(state[None, :], training=False)
-            return tf.squeeze(tf.squeeze(tf.random.categorical(logits, 1), axis=-1), axis=-1)
-
-    def _apply_gradient_descent(self,
-                                memory: ReplayMemory,
-                                batch_size: int,
-                                learning_rate: float,
-                                discount_factor: float,
-                                entropy_c: float,):
-        huber_loss = tf.keras.losses.Huber()
-        wsce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-        optimizer = tf.keras.optimizers.Adam(lr=learning_rate)
-
-        transitions = memory.tail(batch_size)
-        batch = A2CTransition(*zip(*transitions))
-
-        states = tf.convert_to_tensor(batch.state)
-        actions = tf.convert_to_tensor(batch.action)
-        rewards = tf.convert_to_tensor(batch.reward, dtype=tf.float32)
-        dones = tf.convert_to_tensor(batch.done)
-        values = tf.convert_to_tensor(batch.value)
-
-        returns = []
-        exp_weighted_return = 0
-
-        for reward, done in zip(rewards[::-1], dones[::-1]):
-            exp_weighted_return = reward + discount_factor * exp_weighted_return * (1 - int(done))
-            returns += [exp_weighted_return]
-
-        returns = returns[::-1]
-
-        with tf.GradientTape() as tape:
-            state_values = self.critic_network(states)
-            critic_loss_value = huber_loss(returns, state_values)
-
-        gradients = tape.gradient(critic_loss_value, self.critic_network.trainable_variables)
-        optimizer.apply_gradients(zip(gradients, self.critic_network.trainable_variables))
-
-        with tf.GradientTape() as tape:
-            returns = tf.reshape(returns, [batch_size, 1])
-            advantages = returns - values
-
-            actions = tf.cast(actions, tf.int32)
-            logits = self.actor_network(states)
-            policy_loss_value = wsce_loss(actions, logits, sample_weight=advantages)
-
-            probs = tf.nn.softmax(logits)
-            entropy_loss_value = tf.keras.losses.categorical_crossentropy(probs, probs)
-            policy_total_loss_value = policy_loss_value - entropy_c * entropy_loss_value
-
-        gradients = tape.gradient(policy_total_loss_value,
-                                  self.actor_network.trainable_variables)
-        optimizer.apply_gradients(zip(gradients, self.actor_network.trainable_variables))
-
-    def train(self,
-              n_steps: int = None,
-              n_episodes: int = None,
-              save_every: int = None,
-              save_path: str = None,
-              callback: callable = None,
-              **kwargs) -> float:
-        batch_size: int = kwargs.get('batch_size', 128)
-        discount_factor: float = kwargs.get('discount_factor', 0.9999)
-        learning_rate: float = kwargs.get('learning_rate', 0.0001)
-        eps_start: float = kwargs.get('eps_start', 0.9)
-        eps_end: float = kwargs.get('eps_end', 0.05)
-        eps_decay_steps: int = kwargs.get('eps_decay_steps', 200)
-        entropy_c: int = kwargs.get('entropy_c', 0.0001)
-        memory_capacity: int = kwargs.get('memory_capacity', 1000)
-
-        memory = ReplayMemory(memory_capacity, transition_type=A2CTransition)
-        episode = 0
-        steps_done = 0
-        total_reward = 0
-        stop_training = False
-
-        if n_steps and not n_episodes:
-            n_episodes = np.iinfo(np.int32).max
-
-        print('==== AGENT ID: {} ===='.format(self.id))
-
-        while episode < n_episodes and not stop_training:
-            state = self.env.reset()
-            done = False
-
-            print('==== EPISODE ID ({}/{}): {} ===='.format(episode + 1,
-                                                            n_episodes,
-                                                            self.env.episode_id))
-
-            while not done:
-                threshold = eps_end + (eps_start - eps_end) * np.exp(-steps_done / eps_decay_steps)
-                action = self.get_action(state, threshold=threshold)
-                next_state, reward, done, _ = self.env.step(action)
-
-                value = self.critic_network(state[None, :], training=False)
-                value = tf.squeeze(value, axis=-1)
-
-                memory.push(state, action, reward, done, value)
-
-                state = next_state
-                total_reward += reward
-                steps_done += 1
-
-                if len(memory) < batch_size:
-                    continue
-
-                self._apply_gradient_descent(memory,
-                                             batch_size,
-                                             learning_rate,
-                                             discount_factor,
-                                             entropy_c)
-
-                if n_steps and steps_done >= n_steps:
-                    done = True
-                    stop_training = True
-
-            is_checkpoint = save_every and episode % save_every == 0
-
-            if save_path and (is_checkpoint or episode == n_episodes):
-                self.save(save_path, episode=episode)
-
-            episode += 1
-
-        mean_reward = total_reward / steps_done
-
-        return mean_reward
+# Copyright 2019 The TensorTrade Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+
+"""
+References:
+    - http://inoryy.com/post/tensorflow2-deep-reinforcement-learning/#agent-interface
+"""
+
+from deprecated import deprecated
+import random
+import numpy as np
+import tensorflow as tf
+
+from collections import namedtuple
+
+from tensortrade.agents import Agent, ReplayMemory
+from datetime import datetime
+
+A2CTransition = namedtuple('A2CTransition', ['state', 'action', 'reward', 'done', 'value'])
+
+
+@deprecated(version='1.0.4', reason="Builtin agents are being deprecated in favor of external implementations (ie: Ray)")
+class A2CAgent(Agent):
+
+    def __init__(self,
+                 env: 'TradingEnvironment',
+                 shared_network: tf.keras.Model = None,
+                 actor_network: tf.keras.Model = None,
+                 critic_network: tf.keras.Model = None):
+        self.env = env
+        self.n_actions = env.action_space.n
+        self.observation_shape = env.observation_space.shape
+
+        self.shared_network = shared_network or self._build_shared_network()
+        self.actor_network = actor_network or self._build_actor_network()
+        self.critic_network = critic_network or self._build_critic_network()
+
+        self.env.agent_id = self.id
+
+    def _build_shared_network(self):
+        network = tf.keras.Sequential([
+            tf.keras.layers.InputLayer(input_shape=self.observation_shape),
+            tf.keras.layers.Conv1D(filters=64, kernel_size=6, padding="same", activation="tanh"),
+            tf.keras.layers.MaxPooling1D(pool_size=2),
+            tf.keras.layers.Conv1D(filters=32, kernel_size=3, padding="same", activation="tanh"),
+            tf.keras.layers.MaxPooling1D(pool_size=2),
+            tf.keras.layers.Flatten()
+        ])
+
+        return network
+
+    def _build_actor_network(self):
+        actor_head = tf.keras.Sequential([
+            tf.keras.layers.Dense(50, activation='relu'),
+            tf.keras.layers.Dense(self.n_actions, activation='relu')
+        ])
+
+        return tf.keras.Sequential([self.shared_network, actor_head])
+
+    def _build_critic_network(self):
+        critic_head = tf.keras.Sequential([
+            tf.keras.layers.Dense(50, activation='relu'),
+            tf.keras.layers.Dense(25, activation='relu'),
+            tf.keras.layers.Dense(1, activation='relu')
+        ])
+
+        return tf.keras.Sequential([self.shared_network, critic_head])
+
+    def restore(self, path: str, **kwargs):
+        actor_filename: str = kwargs.get('actor_filename', None)
+        critic_filename: str = kwargs.get('critic_filename', None)
+
+        if not actor_filename or not critic_filename:
+            raise ValueError(
+                'The `restore` method requires a directory `path`, a `critic_filename`, and an `actor_filename`.')
+
+        self.actor_network = tf.keras.models.load_model(path + actor_filename)
+        self.critic_network = tf.keras.models.load_model(path + critic_filename)
+
+    def save(self, path: str, **kwargs):
+        episode: int = kwargs.get('episode', None)
+
+        if episode:
+            suffix = self.id[:7] + "__" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".hdf5"
+            actor_filename = "actor_network__" + suffix
+            critic_filename = "critic_network__" + suffix
+        else:
+            actor_filename = "actor_network__" + self.id[:7] + "__" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".hdf5"
+            critic_filename = "critic_network__" + self.id[:7] + "__" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".hdf5"
+
+        self.actor_network.save(path + actor_filename)
+        self.critic_network.save(path + critic_filename)
+
+    def get_action(self, state: np.ndarray, **kwargs) -> int:
+        threshold: float = kwargs.get('threshold', 0)
+
+        rand = random.random()
+
+        if rand < threshold:
+            return np.random.choice(self.n_actions)
+        else:
+            logits = self.actor_network(state[None, :], training=False)
+            return tf.squeeze(tf.squeeze(tf.random.categorical(logits, 1), axis=-1), axis=-1)
+
+    def _apply_gradient_descent(self,
+                                memory: ReplayMemory,
+                                batch_size: int,
+                                learning_rate: float,
+                                discount_factor: float,
+                                entropy_c: float,):
+        huber_loss = tf.keras.losses.Huber()
+        wsce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+        optimizer = tf.keras.optimizers.Adam(lr=learning_rate)
+
+        transitions = memory.tail(batch_size)
+        batch = A2CTransition(*zip(*transitions))
+
+        states = tf.convert_to_tensor(batch.state)
+        actions = tf.convert_to_tensor(batch.action)
+        rewards = tf.convert_to_tensor(batch.reward, dtype=tf.float32)
+        dones = tf.convert_to_tensor(batch.done)
+        values = tf.convert_to_tensor(batch.value)
+
+        returns = []
+        exp_weighted_return = 0
+
+        for reward, done in zip(rewards[::-1], dones[::-1]):
+            exp_weighted_return = reward + discount_factor * exp_weighted_return * (1 - int(done))
+            returns += [exp_weighted_return]
+
+        returns = returns[::-1]
+
+        with tf.GradientTape() as tape:
+            state_values = self.critic_network(states)
+            critic_loss_value = huber_loss(returns, state_values)
+
+        gradients = tape.gradient(critic_loss_value, self.critic_network.trainable_variables)
+        optimizer.apply_gradients(zip(gradients, self.critic_network.trainable_variables))
+
+        with tf.GradientTape() as tape:
+            returns = tf.reshape(returns, [batch_size, 1])
+            advantages = returns - values
+
+            actions = tf.cast(actions, tf.int32)
+            logits = self.actor_network(states)
+            policy_loss_value = wsce_loss(actions, logits, sample_weight=advantages)
+
+            probs = tf.nn.softmax(logits)
+            entropy_loss_value = tf.keras.losses.categorical_crossentropy(probs, probs)
+            policy_total_loss_value = policy_loss_value - entropy_c * entropy_loss_value
+
+        gradients = tape.gradient(policy_total_loss_value,
+                                  self.actor_network.trainable_variables)
+        optimizer.apply_gradients(zip(gradients, self.actor_network.trainable_variables))
+
+    def train(self,
+              n_steps: int = None,
+              n_episodes: int = None,
+              save_every: int = None,
+              save_path: str = None,
+              callback: callable = None,
+              **kwargs) -> float:
+        batch_size: int = kwargs.get('batch_size', 128)
+        discount_factor: float = kwargs.get('discount_factor', 0.9999)
+        learning_rate: float = kwargs.get('learning_rate', 0.0001)
+        eps_start: float = kwargs.get('eps_start', 0.9)
+        eps_end: float = kwargs.get('eps_end', 0.05)
+        eps_decay_steps: int = kwargs.get('eps_decay_steps', 200)
+        entropy_c: int = kwargs.get('entropy_c', 0.0001)
+        memory_capacity: int = kwargs.get('memory_capacity', 1000)
+
+        memory = ReplayMemory(memory_capacity, transition_type=A2CTransition)
+        episode = 0
+        steps_done = 0
+        total_reward = 0
+        stop_training = False
+
+        if n_steps and not n_episodes:
+            n_episodes = np.iinfo(np.int32).max
+
+        print('==== AGENT ID: {} ===='.format(self.id))
+
+        while episode < n_episodes and not stop_training:
+            state = self.env.reset()
+            done = False
+
+            print('==== EPISODE ID ({}/{}): {} ===='.format(episode + 1,
+                                                            n_episodes,
+                                                            self.env.episode_id))
+
+            while not done:
+                threshold = eps_end + (eps_start - eps_end) * np.exp(-steps_done / eps_decay_steps)
+                action = self.get_action(state, threshold=threshold)
+                next_state, reward, done, _ = self.env.step(action)
+
+                value = self.critic_network(state[None, :], training=False)
+                value = tf.squeeze(value, axis=-1)
+
+                memory.push(state, action, reward, done, value)
+
+                state = next_state
+                total_reward += reward
+                steps_done += 1
+
+                if len(memory) < batch_size:
+                    continue
+
+                self._apply_gradient_descent(memory,
+                                             batch_size,
+                                             learning_rate,
+                                             discount_factor,
+                                             entropy_c)
+
+                if n_steps and steps_done >= n_steps:
+                    done = True
+                    stop_training = True
+
+            is_checkpoint = save_every and episode % save_every == 0
+
+            if save_path and (is_checkpoint or episode == n_episodes):
+                self.save(save_path, episode=episode)
+
+            episode += 1
+
+        mean_reward = total_reward / steps_done
+
+        return mean_reward
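Note: the builtin agents now carry a runtime deprecation marker from the `deprecated` package (PyPI project "Deprecated"), and checkpoint filenames use a truncated agent id plus a timestamp instead of the zero-padded episode number. A small sketch of how that decorator behaves when such a class is instantiated; the LegacyAgent class below is invented for illustration and is not part of tensortrade:

import warnings
from deprecated import deprecated  # same decorator the 1.0.4 agents use

@deprecated(version='1.0.4', reason="illustration of the pattern applied to the builtin agents")
class LegacyAgent:
    pass

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")      # DeprecationWarning is hidden by default outside __main__
    LegacyAgent()                        # instantiation triggers the warning
    print(caught[0].category.__name__)   # DeprecationWarning
    print(caught[0].message)             # mentions the class, the reason, and the version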
tensortrade/agents/agent.py CHANGED
@@ -1,49 +1,52 @@
-# Copyright 2019 The TensorTrade Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-
-
-import numpy as np
-
-from abc import ABCMeta, abstractmethod
-
-from tensortrade.core import Identifiable
-
-
-class Agent(Identifiable, metaclass=ABCMeta):
-
-    @abstractmethod
-    def restore(self, path: str, **kwargs):
-        """Restore the agent from the file specified in `path`."""
-        raise NotImplementedError()
-
-    @abstractmethod
-    def save(self, path: str, **kwargs):
-        """Save the agent to the directory specified in `path`."""
-        raise NotImplementedError()
-
-    @abstractmethod
-    def get_action(self, state: np.ndarray, **kwargs) -> int:
-        """Get an action for a specific state in the environment."""
-        raise NotImplementedError()
-
-    @abstractmethod
-    def train(self,
-              n_steps: int = None,
-              n_episodes: int = 10000,
-              save_every: int = None,
-              save_path: str = None,
-              callback: callable = None,
-              **kwargs) -> float:
-        """Train the agent in the environment and return the mean reward."""
-        raise NotImplementedError()
+# Copyright 2019 The TensorTrade Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+
+from deprecated import deprecated
+
+import numpy as np
+
+from abc import ABCMeta, abstractmethod
+
+from tensortrade.core import Identifiable
+
+
+@deprecated(version='1.0.4', reason="Builtin agents are being deprecated in favor of external implementations (ie: Ray)")
+class Agent(Identifiable, metaclass=ABCMeta):
+
+    @abstractmethod
+    def restore(self, path: str, **kwargs):
+        """Restore the agent from the file specified in `path`."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def save(self, path: str, **kwargs):
+        """Save the agent to the directory specified in `path`."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_action(self, state: np.ndarray, **kwargs) -> int:
+        """Get an action for a specific state in the environment."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def train(self,
+              n_steps: int = None,
+              n_episodes: int = 10000,
+              save_every: int = None,
+              save_path: str = None,
+              callback: callable = None,
+              **kwargs) -> float:
+        """Train the agent in the environment and return the mean reward."""
+        raise NotImplementedError()
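For context, the abstract interface itself is unchanged apart from the decorator: subclasses still implement restore, save, get_action and train. A minimal illustrative subclass sketch; the RandomAgent name and its do-nothing behaviour are invented here and are not part of the package:

import numpy as np
from tensortrade.agents import Agent  # abstract base shown in the diff above

class RandomAgent(Agent):
    """Illustrative subclass: satisfies the interface without learning anything."""

    def __init__(self, n_actions: int):
        self.n_actions = n_actions

    def restore(self, path: str, **kwargs):
        pass  # nothing to load for a random policy

    def save(self, path: str, **kwargs):
        pass  # nothing to persist

    def get_action(self, state: np.ndarray, **kwargs) -> int:
        return int(np.random.randint(self.n_actions))  # ignore the state entirely

    def train(self, n_steps=None, n_episodes=10000, save_every=None,
              save_path=None, callback=None, **kwargs) -> float:
        return 0.0  # a random policy has nothing to optimize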