tensortrade 1.0.0b0__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff compares two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that public registry.
- tensortrade/__init__.py +23 -16
- tensortrade/agents/__init__.py +7 -7
- tensortrade/agents/a2c_agent.py +239 -237
- tensortrade/agents/agent.py +52 -49
- tensortrade/agents/dqn_agent.py +375 -202
- tensortrade/agents/parallel/__init__.py +5 -5
- tensortrade/agents/parallel/parallel_dqn_agent.py +172 -170
- tensortrade/agents/parallel/parallel_dqn_model.py +85 -83
- tensortrade/agents/parallel/parallel_dqn_optimizer.py +96 -90
- tensortrade/agents/parallel/parallel_dqn_trainer.py +97 -95
- tensortrade/agents/parallel/parallel_queue.py +95 -92
- tensortrade/agents/replay_memory.py +54 -52
- tensortrade/core/__init__.py +6 -6
- tensortrade/core/base.py +167 -173
- tensortrade/core/clock.py +48 -48
- tensortrade/core/component.py +129 -129
- tensortrade/core/context.py +182 -182
- tensortrade/core/exceptions.py +211 -211
- tensortrade/core/registry.py +45 -45
- tensortrade/data/__init__.py +1 -1
- tensortrade/data/cdd.py +152 -151
- tensortrade/env/__init__.py +2 -2
- tensortrade/env/default/__init__.py +96 -89
- tensortrade/env/default/actions.py +428 -399
- tensortrade/env/default/informers.py +14 -16
- tensortrade/env/default/observers.py +475 -284
- tensortrade/env/default/renderers.py +787 -586
- tensortrade/env/default/rewards.py +360 -240
- tensortrade/env/default/stoppers.py +33 -33
- tensortrade/env/generic/__init__.py +22 -22
- tensortrade/env/generic/components/__init__.py +13 -13
- tensortrade/env/generic/components/action_scheme.py +54 -54
- tensortrade/env/generic/components/informer.py +45 -45
- tensortrade/env/generic/components/observer.py +59 -59
- tensortrade/env/generic/components/renderer.py +86 -86
- tensortrade/env/generic/components/reward_scheme.py +44 -44
- tensortrade/env/generic/components/stopper.py +46 -46
- tensortrade/env/generic/environment.py +211 -163
- tensortrade/feed/__init__.py +5 -5
- tensortrade/feed/api/__init__.py +5 -5
- tensortrade/feed/api/boolean/__init__.py +44 -44
- tensortrade/feed/api/boolean/operations.py +20 -20
- tensortrade/feed/api/float/__init__.py +49 -48
- tensortrade/feed/api/float/accumulators.py +199 -199
- tensortrade/feed/api/float/imputation.py +40 -40
- tensortrade/feed/api/float/operations.py +233 -233
- tensortrade/feed/api/float/ordering.py +105 -105
- tensortrade/feed/api/float/utils.py +140 -140
- tensortrade/feed/api/float/window/__init__.py +3 -3
- tensortrade/feed/api/float/window/ewm.py +459 -452
- tensortrade/feed/api/float/window/expanding.py +189 -185
- tensortrade/feed/api/float/window/rolling.py +227 -217
- tensortrade/feed/api/generic/__init__.py +4 -4
- tensortrade/feed/api/generic/imputation.py +51 -51
- tensortrade/feed/api/generic/operators.py +118 -121
- tensortrade/feed/api/generic/reduce.py +119 -119
- tensortrade/feed/api/generic/warmup.py +54 -54
- tensortrade/feed/api/string/__init__.py +44 -43
- tensortrade/feed/api/string/operations.py +135 -131
- tensortrade/feed/core/__init__.py +3 -3
- tensortrade/feed/core/accessors.py +30 -30
- tensortrade/feed/core/base.py +634 -584
- tensortrade/feed/core/feed.py +120 -59
- tensortrade/feed/core/methods.py +37 -37
- tensortrade/feed/core/mixins.py +23 -23
- tensortrade/feed/core/operators.py +174 -174
- tensortrade/oms/__init__.py +2 -2
- tensortrade/oms/exchanges/__init__.py +1 -1
- tensortrade/oms/exchanges/exchange.py +176 -164
- tensortrade/oms/instruments/__init__.py +5 -5
- tensortrade/oms/instruments/exchange_pair.py +44 -44
- tensortrade/oms/instruments/instrument.py +161 -161
- tensortrade/oms/instruments/quantity.py +321 -318
- tensortrade/oms/instruments/trading_pair.py +58 -58
- tensortrade/oms/orders/__init__.py +13 -13
- tensortrade/oms/orders/broker.py +129 -125
- tensortrade/oms/orders/create.py +312 -312
- tensortrade/oms/orders/criteria.py +218 -218
- tensortrade/oms/orders/order.py +368 -368
- tensortrade/oms/orders/order_listener.py +62 -62
- tensortrade/oms/orders/order_spec.py +102 -102
- tensortrade/oms/orders/trade.py +159 -159
- tensortrade/oms/services/__init__.py +2 -2
- tensortrade/oms/services/execution/__init__.py +4 -4
- tensortrade/oms/services/execution/simulated.py +197 -183
- tensortrade/oms/services/slippage/__init__.py +21 -21
- tensortrade/oms/services/slippage/random_slippage_model.py +56 -56
- tensortrade/oms/services/slippage/slippage_model.py +46 -46
- tensortrade/oms/wallets/__init__.py +20 -20
- tensortrade/oms/wallets/ledger.py +92 -92
- tensortrade/oms/wallets/portfolio.py +330 -329
- tensortrade/oms/wallets/wallet.py +376 -365
- tensortrade/stochastic/__init__.py +12 -12
- tensortrade/stochastic/processes/brownian_motion.py +55 -55
- tensortrade/stochastic/processes/cox.py +103 -103
- tensortrade/stochastic/processes/fbm.py +88 -88
- tensortrade/stochastic/processes/gbm.py +129 -129
- tensortrade/stochastic/processes/heston.py +281 -281
- tensortrade/stochastic/processes/merton.py +91 -91
- tensortrade/stochastic/processes/ornstein_uhlenbeck.py +113 -113
- tensortrade/stochastic/utils/__init__.py +2 -2
- tensortrade/stochastic/utils/helpers.py +180 -179
- tensortrade/stochastic/utils/parameters.py +172 -172
- tensortrade/version.py +1 -1
- tensortrade-1.0.4.dist-info/METADATA +65 -0
- tensortrade-1.0.4.dist-info/RECORD +114 -0
- {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info}/WHEEL +1 -1
- {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info/licenses}/LICENSE +200 -200
- tensortrade-1.0.0b0.dist-info/METADATA +0 -74
- tensortrade-1.0.0b0.dist-info/RECORD +0 -114
- {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info}/top_level.txt +0 -0
tensortrade/env/generic/components/reward_scheme.py
CHANGED
@@ -1,44 +1,44 @@
-# Copyright 2020 The TensorTrade Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-
-from abc import abstractmethod
-
-from tensortrade.core.component import Component
-from tensortrade.core.base import TimeIndexed
-
-
-class RewardScheme(Component, TimeIndexed):
-    """A component to compute the reward at each step of an episode."""
-
-    registered_name = "rewards"
-
-    @abstractmethod
-    def reward(self, env: 'TradingEnv') -> float:
-        """Computes the reward for the current step of an episode.
-
-        Parameters
-        ----------
-        env : `TradingEnv`
-            The trading environment
-
-        Returns
-        -------
-        float
-            The computed reward.
-        """
-        raise NotImplementedError()
-
-    def reset(self) -> None:
-        """Resets the reward scheme."""
-        pass
+# Copyright 2020 The TensorTrade Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from abc import abstractmethod
+
+from tensortrade.core.component import Component
+from tensortrade.core.base import TimeIndexed
+
+
+class RewardScheme(Component, TimeIndexed):
+    """A component to compute the reward at each step of an episode."""
+
+    registered_name = "rewards"
+
+    @abstractmethod
+    def reward(self, env: 'TradingEnv') -> float:
+        """Computes the reward for the current step of an episode.
+
+        Parameters
+        ----------
+        env : `TradingEnv`
+            The trading environment
+
+        Returns
+        -------
+        float
+            The computed reward.
+        """
+        raise NotImplementedError()
+
+    def reset(self) -> None:
+        """Resets the reward scheme."""
+        pass
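Note that the RewardScheme interface itself is identical on both sides of this hunk; the whole-file remove/add most likely reflects whitespace or line-ending normalization. For orientation, a minimal concrete scheme against this interface might look like the sketch below. ConstantReward and its amount parameter are illustrative inventions, not code from this diff, and the import path simply follows the package layout in the file list above.

from tensortrade.env.generic import RewardScheme


class ConstantReward(RewardScheme):
    """Toy scheme that pays a fixed reward on every step (illustration only)."""

    def __init__(self, amount: float = 1.0) -> None:
        super().__init__()
        self.amount = amount

    def reward(self, env: "TradingEnv") -> float:
        # A real scheme would derive the reward from the environment state,
        # e.g. from portfolio performance; here it is simply a constant.
        return self.amount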
tensortrade/env/generic/components/stopper.py
CHANGED
@@ -1,46 +1,46 @@
-# Copyright 2020 The TensorTrade Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-
-from abc import abstractmethod
-
-from tensortrade.core.component import Component
-from tensortrade.core.base import TimeIndexed
-
-
-class Stopper(Component, TimeIndexed):
-    """A component for determining if the environment satisfies a defined
-    stopping criteria.
-    """
-
-    registered_name = "stopper"
-
-    @abstractmethod
-    def stop(self, env: 'TradingEnv') -> bool:
-        """Computes if the environment satisfies the defined stopping criteria.
-
-        Parameters
-        ----------
-        env : `TradingEnv`
-            The trading environment.
-
-        Returns
-        -------
-        bool
-            If the environment should stop or continue.
-        """
-        raise NotImplementedError()
-
-    def reset(self) -> None:
-        """Resets the stopper."""
-        pass
+# Copyright 2020 The TensorTrade Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from abc import abstractmethod
+
+from tensortrade.core.component import Component
+from tensortrade.core.base import TimeIndexed
+
+
+class Stopper(Component, TimeIndexed):
+    """A component for determining if the environment satisfies a defined
+    stopping criteria.
+    """
+
+    registered_name = "stopper"
+
+    @abstractmethod
+    def stop(self, env: 'TradingEnv') -> bool:
+        """Computes if the environment satisfies the defined stopping criteria.
+
+        Parameters
+        ----------
+        env : `TradingEnv`
+            The trading environment.
+
+        Returns
+        -------
+        bool
+            If the environment should stop or continue.
+        """
+        raise NotImplementedError()
+
+    def reset(self) -> None:
+        """Resets the stopper."""
+        pass
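The same applies to Stopper: the interface is unchanged between the two versions. A minimal concrete stopper is sketched below; MaxStepStopper is a hypothetical name used only for illustration, and it relies on the clock.step counter that the rewritten environment.py later in this diff also uses.

from tensortrade.env.generic import Stopper


class MaxStepStopper(Stopper):
    """Toy stopper that ends the episode after a fixed number of steps."""

    def __init__(self, max_steps: int = 1000) -> None:
        super().__init__()
        self.max_steps = max_steps

    def stop(self, env: "TradingEnv") -> bool:
        # env.clock.step is incremented once per call to TradingEnv.step().
        return env.clock.step >= self.max_steps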
tensortrade/env/generic/environment.py
CHANGED
@@ -1,163 +1,211 @@
-# Copyright 2020 The TensorTrade Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-
-import uuid
-import logging
-
-from typing import Dict, Any, Tuple
-
[... removed lines 20-163 of the 1.0.0b0 environment.py are truncated in the source diff view ...]
+# Copyright 2020 The TensorTrade Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import uuid
+import logging
+
+from typing import Dict, Any, Tuple, Optional
+from random import randint
+
+import gymnasium
+import numpy as np
+
+from tensortrade.core import TimeIndexed, Clock, Component
+from tensortrade.env.generic import (
+    ActionScheme,
+    RewardScheme,
+    Observer,
+    Stopper,
+    Informer,
+    Renderer
+)
+
+
+class TradingEnv(gymnasium.Env, TimeIndexed):
+    """A trading environment made for use with Gym-compatible reinforcement
+    learning algorithms.
+
+    Parameters
+    ----------
+    action_scheme : `ActionScheme`
+        A component for generating an action to perform at each step of the
+        environment.
+    reward_scheme : `RewardScheme`
+        A component for computing reward after each step of the environment.
+    observer : `Observer`
+        A component for generating observations after each step of the
+        environment.
+    informer : `Informer`
+        A component for providing information after each step of the
+        environment.
+    renderer : `Renderer`
+        A component for rendering the environment.
+    kwargs : keyword arguments
+        Additional keyword arguments needed to create the environment.
+    """
+
+    agent_id: str = None
+    episode_id: str = None
+
+    def __init__(self,
+                 action_scheme: ActionScheme,
+                 reward_scheme: RewardScheme,
+                 observer: Observer,
+                 stopper: Stopper,
+                 informer: Informer,
+                 renderer: Renderer,
+                 min_periods: int = None,
+                 max_episode_steps: int = None,
+                 random_start_pct: float = 0.00,
+                 device: Optional[str] = None,
+                 **kwargs) -> None:
+        super().__init__()
+        self.clock = Clock()
+
+        self.action_scheme = action_scheme
+        self.reward_scheme = reward_scheme
+        self.observer = observer
+        self.stopper = stopper
+        self.informer = informer
+        self.renderer = renderer
+        self.min_periods = min_periods
+        self.max_episode_steps = max_episode_steps
+        self.random_start_pct = random_start_pct
+        self.device = device
+
+        for c in self.components.values():
+            c.clock = self.clock
+
+        self.action_space = action_scheme.action_space
+        self.observation_space = observer.observation_space
+
+        self._enable_logger = kwargs.get('enable_logger', False)
+        if self._enable_logger:
+            self.logger = logging.getLogger(kwargs.get('logger_name', __name__))
+            self.logger.setLevel(kwargs.get('log_level', logging.DEBUG))
+
+    def _ensure_numpy(self, obs: Any) -> np.ndarray:
+        """Ensure observation is returned as numpy array for GPU compatibility.
+
+        Parameters
+        ----------
+        obs : Any
+            The observation to convert
+
+        Returns
+        -------
+        np.ndarray
+            The observation as a numpy array
+        """
+        if hasattr(obs, 'cpu'): # PyTorch tensor
+            return obs.cpu().numpy()
+        elif hasattr(obs, 'numpy'): # TensorFlow tensor
+            return obs.numpy()
+        elif isinstance(obs, np.ndarray):
+            return obs
+        else:
+            return np.array(obs)
+
+    @property
+    def components(self) -> 'Dict[str, Component]':
+        """The components of the environment. (`Dict[str,Component]`, read-only)"""
+        return {
+            "action_scheme": self.action_scheme,
+            "reward_scheme": self.reward_scheme,
+            "observer": self.observer,
+            "stopper": self.stopper,
+            "informer": self.informer,
+            "renderer": self.renderer
+        }
+
+    def step(self, action: Any) -> 'Tuple[np.array, float, bool, dict]':
+        """Makes one step through the environment.
+
+        Parameters
+        ----------
+        action : Any
+            An action to perform on the environment.
+
+        Returns
+        -------
+        `np.array`
+            The observation of the environment after the action being
+            performed.
+        float
+            The computed reward for performing the action.
+        bool
+            Whether or not the episode is complete.
+        dict
+            The information gathered after completing the step.
+        """
+        self.action_scheme.perform(self, action)
+
+        obs = self.observer.observe(self)
+        # Ensure observation is numpy array for GPU compatibility
+        obs = self._ensure_numpy(obs)
+        reward = self.reward_scheme.reward(self)
+        terminated = self.stopper.stop(self)
+        # Check if episode should be truncated due to max steps
+        truncated = (self.max_episode_steps is not None and
+                     self.clock.step >= self.max_episode_steps)
+        info = self.informer.info(self)
+
+        self.clock.increment()
+
+        return obs, reward, terminated, truncated, info
+
+    def reset(self,seed = None, options = None) -> tuple["np.array", dict[str, Any]]:
+        """Resets the environment.
+
+        Returns
+        -------
+        obs : `np.array`
+            The first observation of the environment.
+        """
+        if self.random_start_pct > 0.00:
+            size = len(self.observer.feed.process[-1].inputs[0].iterable)
+            random_start = randint(0, int(size * self.random_start_pct))
+        else:
+            random_start = 0
+
+        self.episode_id = str(uuid.uuid4())
+        self.clock.reset()
+
+        for c in self.components.values():
+            if hasattr(c, "reset"):
+                if isinstance(c, Observer):
+                    c.reset(random_start=random_start)
+                else:
+                    c.reset()
+
+        obs = self.observer.observe(self)
+        # Ensure observation is numpy array for GPU compatibility
+        obs = self._ensure_numpy(obs)
+        info = self.informer.info(self)
+
+        self.clock.increment()
+
+        return obs, info
+
+    def render(self, **kwargs) -> None:
+        """Renders the environment."""
+        self.renderer.render(self, **kwargs)
+
+    def save(self) -> None:
+        """Saves the rendered view of the environment."""
+        self.renderer.save()
+
+    def close(self) -> None:
+        """Closes the environment."""
+        self.renderer.close()
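The rewritten environment.py targets Gymnasium rather than the legacy gym API: reset() now returns an (observation, info) pair and step() returns the five-tuple (observation, reward, terminated, truncated, info), with truncation driven by the new max_episode_steps argument. A minimal rollout loop against this interface might look like the sketch below; the env variable is assumed to be an already constructed TradingEnv, and the random policy is only a stand-in for a real agent.

# Rollout sketch against the Gymnasium-style API shown above (illustrative only;
# `env` is assumed to be a constructed TradingEnv instance).
obs, info = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # stand-in for a trained agent's policy
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.render()
env.close()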
tensortrade/feed/__init__.py
CHANGED
@@ -1,5 +1,5 @@
-
-from . import api
-from . import core
-
-from .core import Stream, NameSpace, DataFeed
+
+from . import api
+from . import core
+
+from .core import Stream, NameSpace, DataFeed
tensortrade/feed/api/__init__.py
CHANGED
@@ -1,5 +1,5 @@
-
-from . import generic
-from . import float
-from . import boolean
-from . import string
+
+from . import generic
+from . import float
+from . import boolean
+from . import string