tensortrade 1.0.0b0__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. tensortrade/__init__.py +23 -16
  2. tensortrade/agents/__init__.py +7 -7
  3. tensortrade/agents/a2c_agent.py +239 -237
  4. tensortrade/agents/agent.py +52 -49
  5. tensortrade/agents/dqn_agent.py +375 -202
  6. tensortrade/agents/parallel/__init__.py +5 -5
  7. tensortrade/agents/parallel/parallel_dqn_agent.py +172 -170
  8. tensortrade/agents/parallel/parallel_dqn_model.py +85 -83
  9. tensortrade/agents/parallel/parallel_dqn_optimizer.py +96 -90
  10. tensortrade/agents/parallel/parallel_dqn_trainer.py +97 -95
  11. tensortrade/agents/parallel/parallel_queue.py +95 -92
  12. tensortrade/agents/replay_memory.py +54 -52
  13. tensortrade/core/__init__.py +6 -6
  14. tensortrade/core/base.py +167 -173
  15. tensortrade/core/clock.py +48 -48
  16. tensortrade/core/component.py +129 -129
  17. tensortrade/core/context.py +182 -182
  18. tensortrade/core/exceptions.py +211 -211
  19. tensortrade/core/registry.py +45 -45
  20. tensortrade/data/__init__.py +1 -1
  21. tensortrade/data/cdd.py +152 -151
  22. tensortrade/env/__init__.py +2 -2
  23. tensortrade/env/default/__init__.py +96 -89
  24. tensortrade/env/default/actions.py +428 -399
  25. tensortrade/env/default/informers.py +14 -16
  26. tensortrade/env/default/observers.py +475 -284
  27. tensortrade/env/default/renderers.py +787 -586
  28. tensortrade/env/default/rewards.py +360 -240
  29. tensortrade/env/default/stoppers.py +33 -33
  30. tensortrade/env/generic/__init__.py +22 -22
  31. tensortrade/env/generic/components/__init__.py +13 -13
  32. tensortrade/env/generic/components/action_scheme.py +54 -54
  33. tensortrade/env/generic/components/informer.py +45 -45
  34. tensortrade/env/generic/components/observer.py +59 -59
  35. tensortrade/env/generic/components/renderer.py +86 -86
  36. tensortrade/env/generic/components/reward_scheme.py +44 -44
  37. tensortrade/env/generic/components/stopper.py +46 -46
  38. tensortrade/env/generic/environment.py +211 -163
  39. tensortrade/feed/__init__.py +5 -5
  40. tensortrade/feed/api/__init__.py +5 -5
  41. tensortrade/feed/api/boolean/__init__.py +44 -44
  42. tensortrade/feed/api/boolean/operations.py +20 -20
  43. tensortrade/feed/api/float/__init__.py +49 -48
  44. tensortrade/feed/api/float/accumulators.py +199 -199
  45. tensortrade/feed/api/float/imputation.py +40 -40
  46. tensortrade/feed/api/float/operations.py +233 -233
  47. tensortrade/feed/api/float/ordering.py +105 -105
  48. tensortrade/feed/api/float/utils.py +140 -140
  49. tensortrade/feed/api/float/window/__init__.py +3 -3
  50. tensortrade/feed/api/float/window/ewm.py +459 -452
  51. tensortrade/feed/api/float/window/expanding.py +189 -185
  52. tensortrade/feed/api/float/window/rolling.py +227 -217
  53. tensortrade/feed/api/generic/__init__.py +4 -4
  54. tensortrade/feed/api/generic/imputation.py +51 -51
  55. tensortrade/feed/api/generic/operators.py +118 -121
  56. tensortrade/feed/api/generic/reduce.py +119 -119
  57. tensortrade/feed/api/generic/warmup.py +54 -54
  58. tensortrade/feed/api/string/__init__.py +44 -43
  59. tensortrade/feed/api/string/operations.py +135 -131
  60. tensortrade/feed/core/__init__.py +3 -3
  61. tensortrade/feed/core/accessors.py +30 -30
  62. tensortrade/feed/core/base.py +634 -584
  63. tensortrade/feed/core/feed.py +120 -59
  64. tensortrade/feed/core/methods.py +37 -37
  65. tensortrade/feed/core/mixins.py +23 -23
  66. tensortrade/feed/core/operators.py +174 -174
  67. tensortrade/oms/__init__.py +2 -2
  68. tensortrade/oms/exchanges/__init__.py +1 -1
  69. tensortrade/oms/exchanges/exchange.py +176 -164
  70. tensortrade/oms/instruments/__init__.py +5 -5
  71. tensortrade/oms/instruments/exchange_pair.py +44 -44
  72. tensortrade/oms/instruments/instrument.py +161 -161
  73. tensortrade/oms/instruments/quantity.py +321 -318
  74. tensortrade/oms/instruments/trading_pair.py +58 -58
  75. tensortrade/oms/orders/__init__.py +13 -13
  76. tensortrade/oms/orders/broker.py +129 -125
  77. tensortrade/oms/orders/create.py +312 -312
  78. tensortrade/oms/orders/criteria.py +218 -218
  79. tensortrade/oms/orders/order.py +368 -368
  80. tensortrade/oms/orders/order_listener.py +62 -62
  81. tensortrade/oms/orders/order_spec.py +102 -102
  82. tensortrade/oms/orders/trade.py +159 -159
  83. tensortrade/oms/services/__init__.py +2 -2
  84. tensortrade/oms/services/execution/__init__.py +4 -4
  85. tensortrade/oms/services/execution/simulated.py +197 -183
  86. tensortrade/oms/services/slippage/__init__.py +21 -21
  87. tensortrade/oms/services/slippage/random_slippage_model.py +56 -56
  88. tensortrade/oms/services/slippage/slippage_model.py +46 -46
  89. tensortrade/oms/wallets/__init__.py +20 -20
  90. tensortrade/oms/wallets/ledger.py +92 -92
  91. tensortrade/oms/wallets/portfolio.py +330 -329
  92. tensortrade/oms/wallets/wallet.py +376 -365
  93. tensortrade/stochastic/__init__.py +12 -12
  94. tensortrade/stochastic/processes/brownian_motion.py +55 -55
  95. tensortrade/stochastic/processes/cox.py +103 -103
  96. tensortrade/stochastic/processes/fbm.py +88 -88
  97. tensortrade/stochastic/processes/gbm.py +129 -129
  98. tensortrade/stochastic/processes/heston.py +281 -281
  99. tensortrade/stochastic/processes/merton.py +91 -91
  100. tensortrade/stochastic/processes/ornstein_uhlenbeck.py +113 -113
  101. tensortrade/stochastic/utils/__init__.py +2 -2
  102. tensortrade/stochastic/utils/helpers.py +180 -179
  103. tensortrade/stochastic/utils/parameters.py +172 -172
  104. tensortrade/version.py +1 -1
  105. tensortrade-1.0.4.dist-info/METADATA +65 -0
  106. tensortrade-1.0.4.dist-info/RECORD +114 -0
  107. {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info}/WHEEL +1 -1
  108. {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info/licenses}/LICENSE +200 -200
  109. tensortrade-1.0.0b0.dist-info/METADATA +0 -74
  110. tensortrade-1.0.0b0.dist-info/RECORD +0 -114
  111. {tensortrade-1.0.0b0.dist-info → tensortrade-1.0.4.dist-info}/top_level.txt +0 -0
@@ -1,44 +1,44 @@
1
- # Copyright 2020 The TensorTrade Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License
14
-
15
- from abc import abstractmethod
16
-
17
- from tensortrade.core.component import Component
18
- from tensortrade.core.base import TimeIndexed
19
-
20
-
21
- class RewardScheme(Component, TimeIndexed):
22
- """A component to compute the reward at each step of an episode."""
23
-
24
- registered_name = "rewards"
25
-
26
- @abstractmethod
27
- def reward(self, env: 'TradingEnv') -> float:
28
- """Computes the reward for the current step of an episode.
29
-
30
- Parameters
31
- ----------
32
- env : `TradingEnv`
33
- The trading environment
34
-
35
- Returns
36
- -------
37
- float
38
- The computed reward.
39
- """
40
- raise NotImplementedError()
41
-
42
- def reset(self) -> None:
43
- """Resets the reward scheme."""
44
- pass
1
+ # Copyright 2020 The TensorTrade Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License
14
+
15
+ from abc import abstractmethod
16
+
17
+ from tensortrade.core.component import Component
18
+ from tensortrade.core.base import TimeIndexed
19
+
20
+
21
+ class RewardScheme(Component, TimeIndexed):
22
+ """A component to compute the reward at each step of an episode."""
23
+
24
+ registered_name = "rewards"
25
+
26
+ @abstractmethod
27
+ def reward(self, env: 'TradingEnv') -> float:
28
+ """Computes the reward for the current step of an episode.
29
+
30
+ Parameters
31
+ ----------
32
+ env : `TradingEnv`
33
+ The trading environment
34
+
35
+ Returns
36
+ -------
37
+ float
38
+ The computed reward.
39
+ """
40
+ raise NotImplementedError()
41
+
42
+ def reset(self) -> None:
43
+ """Resets the reward scheme."""
44
+ pass
@@ -1,46 +1,46 @@
1
- # Copyright 2020 The TensorTrade Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License
14
-
15
- from abc import abstractmethod
16
-
17
- from tensortrade.core.component import Component
18
- from tensortrade.core.base import TimeIndexed
19
-
20
-
21
- class Stopper(Component, TimeIndexed):
22
- """A component for determining if the environment satisfies a defined
23
- stopping criteria.
24
- """
25
-
26
- registered_name = "stopper"
27
-
28
- @abstractmethod
29
- def stop(self, env: 'TradingEnv') -> bool:
30
- """Computes if the environment satisfies the defined stopping criteria.
31
-
32
- Parameters
33
- ----------
34
- env : `TradingEnv`
35
- The trading environment.
36
-
37
- Returns
38
- -------
39
- bool
40
- If the environment should stop or continue.
41
- """
42
- raise NotImplementedError()
43
-
44
- def reset(self) -> None:
45
- """Resets the stopper."""
46
- pass
1
+ # Copyright 2020 The TensorTrade Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License
14
+
15
+ from abc import abstractmethod
16
+
17
+ from tensortrade.core.component import Component
18
+ from tensortrade.core.base import TimeIndexed
19
+
20
+
21
+ class Stopper(Component, TimeIndexed):
22
+ """A component for determining if the environment satisfies a defined
23
+ stopping criteria.
24
+ """
25
+
26
+ registered_name = "stopper"
27
+
28
+ @abstractmethod
29
+ def stop(self, env: 'TradingEnv') -> bool:
30
+ """Computes if the environment satisfies the defined stopping criteria.
31
+
32
+ Parameters
33
+ ----------
34
+ env : `TradingEnv`
35
+ The trading environment.
36
+
37
+ Returns
38
+ -------
39
+ bool
40
+ If the environment should stop or continue.
41
+ """
42
+ raise NotImplementedError()
43
+
44
+ def reset(self) -> None:
45
+ """Resets the stopper."""
46
+ pass
@@ -1,163 +1,211 @@
1
- # Copyright 2020 The TensorTrade Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License
14
-
15
- import uuid
16
- import logging
17
-
18
- from typing import Dict, Any, Tuple
19
-
20
- import gym
21
- import numpy as np
22
-
23
- from tensortrade.core import TimeIndexed, Clock, Component
24
- from tensortrade.env.generic import (
25
- ActionScheme,
26
- RewardScheme,
27
- Observer,
28
- Stopper,
29
- Informer,
30
- Renderer
31
- )
32
-
33
-
34
- class TradingEnv(gym.Env, TimeIndexed):
35
- """A trading environment made for use with Gym-compatible reinforcement
36
- learning algorithms.
37
-
38
- Parameters
39
- ----------
40
- action_scheme : `ActionScheme`
41
- A component for generating an action to perform at each step of the
42
- environment.
43
- reward_scheme : `RewardScheme`
44
- A component for computing reward after each step of the environment.
45
- observer : `Observer`
46
- A component for generating observations after each step of the
47
- environment.
48
- informer : `Informer`
49
- A component for providing information after each step of the
50
- environment.
51
- renderer : `Renderer`
52
- A component for rendering the environment.
53
- kwargs : keyword arguments
54
- Additional keyword arguments needed to create the environment.
55
- """
56
-
57
- agent_id: str = None
58
- episode_id: str = None
59
-
60
- def __init__(self,
61
- action_scheme: ActionScheme,
62
- reward_scheme: RewardScheme,
63
- observer: Observer,
64
- stopper: Stopper,
65
- informer: Informer,
66
- renderer: Renderer,
67
- **kwargs) -> None:
68
- super().__init__()
69
- self.clock = Clock()
70
-
71
- self.action_scheme = action_scheme
72
- self.reward_scheme = reward_scheme
73
- self.observer = observer
74
- self.stopper = stopper
75
- self.informer = informer
76
- self.renderer = renderer
77
-
78
- for c in self.components.values():
79
- c.clock = self.clock
80
-
81
- self.action_space = action_scheme.action_space
82
- self.observation_space = observer.observation_space
83
-
84
- self._enable_logger = kwargs.get('enable_logger', False)
85
- if self._enable_logger:
86
- self.logger = logging.getLogger(kwargs.get('logger_name', __name__))
87
- self.logger.setLevel(kwargs.get('log_level', logging.DEBUG))
88
-
89
- @property
90
- def components(self) -> 'Dict[str, Component]':
91
- """The components of the environment. (`Dict[str,Component]`, read-only)"""
92
- return {
93
- "action_scheme": self.action_scheme,
94
- "reward_scheme": self.reward_scheme,
95
- "observer": self.observer,
96
- "stopper": self.stopper,
97
- "informer": self.informer,
98
- "renderer": self.renderer
99
- }
100
-
101
- def step(self, action: Any) -> 'Tuple[np.array, float, bool, dict]':
102
- """Makes on step through the environment.
103
-
104
- Parameters
105
- ----------
106
- action : Any
107
- An action to perform on the environment.
108
-
109
- Returns
110
- -------
111
- `np.array`
112
- The observation of the environment after the action being
113
- performed.
114
- float
115
- The computed reward for performing the action.
116
- bool
117
- Whether or not the episode is complete.
118
- dict
119
- The information gathered after completing the step.
120
- """
121
- self.action_scheme.perform(self, action)
122
-
123
- obs = self.observer.observe(self)
124
- reward = self.reward_scheme.reward(self)
125
- done = self.stopper.stop(self)
126
- info = self.informer.info(self)
127
-
128
- self.clock.increment()
129
-
130
- return obs, reward, done, info
131
-
132
- def reset(self) -> 'np.array':
133
- """Resets the environment.
134
-
135
- Returns
136
- -------
137
- obs : `np.array`
138
- The first observation of the environment.
139
- """
140
- self.episode_id = str(uuid.uuid4())
141
- self.clock.reset()
142
-
143
- for c in self.components.values():
144
- if hasattr(c, "reset"):
145
- c.reset()
146
-
147
- obs = self.observer.observe(self)
148
-
149
- self.clock.increment()
150
-
151
- return obs
152
-
153
- def render(self, **kwargs) -> None:
154
- """Renders the environment."""
155
- self.renderer.render(self, **kwargs)
156
-
157
- def save(self) -> None:
158
- """Saves the rendered view of the environment."""
159
- self.renderer.save()
160
-
161
- def close(self) -> None:
162
- """Closes the environment."""
163
- self.renderer.close()
1
+ # Copyright 2020 The TensorTrade Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License
14
+
15
+ import uuid
16
+ import logging
17
+
18
+ from typing import Dict, Any, Tuple, Optional
19
+ from random import randint
20
+
21
+ import gymnasium
22
+ import numpy as np
23
+
24
+ from tensortrade.core import TimeIndexed, Clock, Component
25
+ from tensortrade.env.generic import (
26
+ ActionScheme,
27
+ RewardScheme,
28
+ Observer,
29
+ Stopper,
30
+ Informer,
31
+ Renderer
32
+ )
33
+
34
+
35
+ class TradingEnv(gymnasium.Env, TimeIndexed):
36
+ """A trading environment made for use with Gym-compatible reinforcement
37
+ learning algorithms.
38
+
39
+ Parameters
40
+ ----------
41
+ action_scheme : `ActionScheme`
42
+ A component for generating an action to perform at each step of the
43
+ environment.
44
+ reward_scheme : `RewardScheme`
45
+ A component for computing reward after each step of the environment.
46
+ observer : `Observer`
47
+ A component for generating observations after each step of the
48
+ environment.
49
+ informer : `Informer`
50
+ A component for providing information after each step of the
51
+ environment.
52
+ renderer : `Renderer`
53
+ A component for rendering the environment.
54
+ kwargs : keyword arguments
55
+ Additional keyword arguments needed to create the environment.
56
+ """
57
+
58
+ agent_id: str = None
59
+ episode_id: str = None
60
+
61
+ def __init__(self,
62
+ action_scheme: ActionScheme,
63
+ reward_scheme: RewardScheme,
64
+ observer: Observer,
65
+ stopper: Stopper,
66
+ informer: Informer,
67
+ renderer: Renderer,
68
+ min_periods: int = None,
69
+ max_episode_steps: int = None,
70
+ random_start_pct: float = 0.00,
71
+ device: Optional[str] = None,
72
+ **kwargs) -> None:
73
+ super().__init__()
74
+ self.clock = Clock()
75
+
76
+ self.action_scheme = action_scheme
77
+ self.reward_scheme = reward_scheme
78
+ self.observer = observer
79
+ self.stopper = stopper
80
+ self.informer = informer
81
+ self.renderer = renderer
82
+ self.min_periods = min_periods
83
+ self.max_episode_steps = max_episode_steps
84
+ self.random_start_pct = random_start_pct
85
+ self.device = device
86
+
87
+ for c in self.components.values():
88
+ c.clock = self.clock
89
+
90
+ self.action_space = action_scheme.action_space
91
+ self.observation_space = observer.observation_space
92
+
93
+ self._enable_logger = kwargs.get('enable_logger', False)
94
+ if self._enable_logger:
95
+ self.logger = logging.getLogger(kwargs.get('logger_name', __name__))
96
+ self.logger.setLevel(kwargs.get('log_level', logging.DEBUG))
97
+
98
+ def _ensure_numpy(self, obs: Any) -> np.ndarray:
99
+ """Ensure observation is returned as numpy array for GPU compatibility.
100
+
101
+ Parameters
102
+ ----------
103
+ obs : Any
104
+ The observation to convert
105
+
106
+ Returns
107
+ -------
108
+ np.ndarray
109
+ The observation as a numpy array
110
+ """
111
+ if hasattr(obs, 'cpu'): # PyTorch tensor
112
+ return obs.cpu().numpy()
113
+ elif hasattr(obs, 'numpy'): # TensorFlow tensor
114
+ return obs.numpy()
115
+ elif isinstance(obs, np.ndarray):
116
+ return obs
117
+ else:
118
+ return np.array(obs)
119
+
120
+ @property
121
+ def components(self) -> 'Dict[str, Component]':
122
+ """The components of the environment. (`Dict[str,Component]`, read-only)"""
123
+ return {
124
+ "action_scheme": self.action_scheme,
125
+ "reward_scheme": self.reward_scheme,
126
+ "observer": self.observer,
127
+ "stopper": self.stopper,
128
+ "informer": self.informer,
129
+ "renderer": self.renderer
130
+ }
131
+
132
+ def step(self, action: Any) -> 'Tuple[np.array, float, bool, dict]':
133
+ """Makes one step through the environment.
134
+
135
+ Parameters
136
+ ----------
137
+ action : Any
138
+ An action to perform on the environment.
139
+
140
+ Returns
141
+ -------
142
+ `np.array`
143
+ The observation of the environment after the action being
144
+ performed.
145
+ float
146
+ The computed reward for performing the action.
147
+ bool
148
+ Whether or not the episode is complete.
149
+ dict
150
+ The information gathered after completing the step.
151
+ """
152
+ self.action_scheme.perform(self, action)
153
+
154
+ obs = self.observer.observe(self)
155
+ # Ensure observation is numpy array for GPU compatibility
156
+ obs = self._ensure_numpy(obs)
157
+ reward = self.reward_scheme.reward(self)
158
+ terminated = self.stopper.stop(self)
159
+ # Check if episode should be truncated due to max steps
160
+ truncated = (self.max_episode_steps is not None and
161
+ self.clock.step >= self.max_episode_steps)
162
+ info = self.informer.info(self)
163
+
164
+ self.clock.increment()
165
+
166
+ return obs, reward, terminated, truncated, info
167
+
168
+ def reset(self,seed = None, options = None) -> tuple["np.array", dict[str, Any]]:
169
+ """Resets the environment.
170
+
171
+ Returns
172
+ -------
173
+ obs : `np.array`
174
+ The first observation of the environment.
175
+ """
176
+ if self.random_start_pct > 0.00:
177
+ size = len(self.observer.feed.process[-1].inputs[0].iterable)
178
+ random_start = randint(0, int(size * self.random_start_pct))
179
+ else:
180
+ random_start = 0
181
+
182
+ self.episode_id = str(uuid.uuid4())
183
+ self.clock.reset()
184
+
185
+ for c in self.components.values():
186
+ if hasattr(c, "reset"):
187
+ if isinstance(c, Observer):
188
+ c.reset(random_start=random_start)
189
+ else:
190
+ c.reset()
191
+
192
+ obs = self.observer.observe(self)
193
+ # Ensure observation is numpy array for GPU compatibility
194
+ obs = self._ensure_numpy(obs)
195
+ info = self.informer.info(self)
196
+
197
+ self.clock.increment()
198
+
199
+ return obs, info
200
+
201
+ def render(self, **kwargs) -> None:
202
+ """Renders the environment."""
203
+ self.renderer.render(self, **kwargs)
204
+
205
+ def save(self) -> None:
206
+ """Saves the rendered view of the environment."""
207
+ self.renderer.save()
208
+
209
+ def close(self) -> None:
210
+ """Closes the environment."""
211
+ self.renderer.close()
@@ -1,5 +1,5 @@
1
-
2
- from . import api
3
- from . import core
4
-
5
- from .core import Stream, NameSpace, DataFeed
1
+
2
+ from . import api
3
+ from . import core
4
+
5
+ from .core import Stream, NameSpace, DataFeed
@@ -1,5 +1,5 @@
1
-
2
- from . import generic
3
- from . import float
4
- from . import boolean
5
- from . import string
1
+
2
+ from . import generic
3
+ from . import float
4
+ from . import boolean
5
+ from . import string