analyser_hj3415 3.4.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- analyser_hj3415/analyser/eval/red.py +66 -24
- analyser_hj3415/analyser/tsa/lstm.py +138 -169
- analyser_hj3415/analyser/tsa/prophet.py +172 -156
- analyser_hj3415/cli.py +27 -44
- {analyser_hj3415-3.4.2.dist-info → analyser_hj3415-4.0.0.dist-info}/METADATA +1 -1
- analyser_hj3415-4.0.0.dist-info/RECORD +17 -0
- analyser_hj3415/analyser/compile.py +0 -355
- analyser_hj3415/workroom/__init__.py +0 -0
- analyser_hj3415/workroom/mysklearn.py +0 -50
- analyser_hj3415/workroom/mysklearn2.py +0 -39
- analyser_hj3415/workroom/score.py +0 -342
- analyser_hj3415/workroom/trash.py +0 -289
- analyser_hj3415-3.4.2.dist-info/RECORD +0 -23
- {analyser_hj3415-3.4.2.dist-info → analyser_hj3415-4.0.0.dist-info}/WHEEL +0 -0
- {analyser_hj3415-3.4.2.dist-info → analyser_hj3415-4.0.0.dist-info}/entry_points.txt +0 -0
@@ -1,355 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
from collections import OrderedDict
|
3
|
-
from typing import Union
|
4
|
-
from dataclasses import dataclass
|
5
|
-
|
6
|
-
from db_hj3415 import myredis
|
7
|
-
from utils_hj3415 import tools, setup_logger
|
8
|
-
|
9
|
-
from analyser_hj3415.analyser import tsa, eval, MIs
|
10
|
-
|
11
|
-
mylogger = setup_logger(__name__,'WARNING')
|
12
|
-
expire_time = tools.to_int(os.getenv('DEFAULT_EXPIRE_TIME_H', 48)) * 3600
|
13
|
-
|
14
|
-
@dataclass
|
15
|
-
class MICompileData:
|
16
|
-
"""
|
17
|
-
MI(Market Index) 데이터를 컴파일하여 저장하는 데이터 클래스.
|
18
|
-
|
19
|
-
속성:
|
20
|
-
mi_type (str): 시장 지수 유형.
|
21
|
-
prophet_data (tsa.ProphetData): Prophet 예측 데이터.
|
22
|
-
lstm_grade (tsa.LSTMGrade): LSTM 등급 데이터.
|
23
|
-
is_lstm_up (bool): LSTM 상승 여부.
|
24
|
-
is_prophet_up (bool): Prophet 상승 여부.
|
25
|
-
lstm_html (str): LSTM 시각화 HTML.
|
26
|
-
prophet_html (str): Prophet 시각화 HTML.
|
27
|
-
"""
|
28
|
-
mi_type: str
|
29
|
-
|
30
|
-
prophet_data: tsa.ProphetData
|
31
|
-
lstm_grade: tsa.LSTMGrade
|
32
|
-
|
33
|
-
is_lstm_up: bool = False
|
34
|
-
is_prophet_up: bool = False
|
35
|
-
|
36
|
-
lstm_html: str = ''
|
37
|
-
prophet_html: str = ''
|
38
|
-
|
39
|
-
|
40
|
-
class MICompile:
|
41
|
-
"""
|
42
|
-
MI(Market Index) 데이터를 컴파일하는 클래스.
|
43
|
-
|
44
|
-
메서드:
|
45
|
-
get(refresh=False) -> MICompileData:
|
46
|
-
MI 데이터를 컴파일하거나 캐시에서 가져옵니다.
|
47
|
-
|
48
|
-
analyser_lstm_all_mi(refresh: bool):
|
49
|
-
모든 MI에 대해 LSTM 예측 및 초기화 수행.
|
50
|
-
"""
|
51
|
-
def __init__(self, mi_type: str):
|
52
|
-
"""
|
53
|
-
MICompile 객체를 초기화합니다.
|
54
|
-
|
55
|
-
매개변수:
|
56
|
-
mi_type (str): 시장 지수 유형.
|
57
|
-
"""
|
58
|
-
assert mi_type in MIs._fields, f"Invalid MI type ({MIs._fields})"
|
59
|
-
self._mi_type = mi_type
|
60
|
-
|
61
|
-
@property
|
62
|
-
def mi_type(self) -> str:
|
63
|
-
"""
|
64
|
-
MI 유형을 반환합니다.
|
65
|
-
|
66
|
-
반환값:
|
67
|
-
str: MI 유형.
|
68
|
-
"""
|
69
|
-
return self._mi_type
|
70
|
-
|
71
|
-
@mi_type.setter
|
72
|
-
def mi_type(self, mi_type: str):
|
73
|
-
"""
|
74
|
-
MI 유형을 변경합니다.
|
75
|
-
|
76
|
-
매개변수:
|
77
|
-
mi_type (str): 새로 설정할 MI 유형.
|
78
|
-
"""
|
79
|
-
assert mi_type in MIs._fields, f"Invalid MI type ({MIs._fields})"
|
80
|
-
self._mi_type = mi_type
|
81
|
-
|
82
|
-
def get(self, refresh=False) -> MICompileData:
|
83
|
-
"""
|
84
|
-
MI 데이터를 컴파일하거나 캐시에서 가져옵니다.
|
85
|
-
|
86
|
-
매개변수:
|
87
|
-
refresh (bool): 데이터를 새로 가져올지 여부.
|
88
|
-
|
89
|
-
반환값:
|
90
|
-
MICompileData: 컴파일된 MI 데이터.
|
91
|
-
"""
|
92
|
-
print(f"{self.mi_type}의 compiling을 시작합니다.")
|
93
|
-
redis_name = self.mi_type + '_mi_compile'
|
94
|
-
print(
|
95
|
-
f"redisname: '{redis_name}' / refresh : {refresh} / expire_time : {expire_time / 3600}h")
|
96
|
-
|
97
|
-
def fetch_mi_compile_data() -> MICompileData:
|
98
|
-
prophet = tsa.MIProphet(self.mi_type)
|
99
|
-
lstm = tsa.MILSTM(self.mi_type)
|
100
|
-
|
101
|
-
data = MICompileData(
|
102
|
-
mi_type=self.mi_type,
|
103
|
-
prophet_data=prophet.generate_data(refresh=refresh),
|
104
|
-
lstm_grade=lstm.get_final_predictions(refresh=refresh)[1],
|
105
|
-
)
|
106
|
-
data.is_lstm_up = lstm.is_lstm_up()
|
107
|
-
data.is_prophet_up = prophet.is_prophet_up(refresh=False)
|
108
|
-
data.lstm_html = lstm.export(refresh=False)
|
109
|
-
data.prophet_html = prophet.export()
|
110
|
-
return data
|
111
|
-
|
112
|
-
mi_compile_data = myredis.Base.fetch_and_cache_data(redis_name, refresh, fetch_mi_compile_data, timer=expire_time)
|
113
|
-
return mi_compile_data
|
114
|
-
|
115
|
-
@staticmethod
|
116
|
-
def caching_mi_compile_all(refresh: bool):
|
117
|
-
"""
|
118
|
-
모든 MI(Market Index)에 대해 MICompileData를 캐싱합니다..
|
119
|
-
|
120
|
-
매개변수:
|
121
|
-
refresh (bool): 데이터를 새로 가져올지 여부.
|
122
|
-
"""
|
123
|
-
mi_compile = MICompile('WTI')
|
124
|
-
print(f"*** MICompileData caching Market Index items ***")
|
125
|
-
for mi_type in MIs._fields:
|
126
|
-
mi_compile.mi_type = mi_type
|
127
|
-
print(f"{mi_type}")
|
128
|
-
mi_compile_data = mi_compile.get(refresh=refresh)
|
129
|
-
print(mi_compile_data)
|
130
|
-
|
131
|
-
|
132
|
-
@dataclass
|
133
|
-
class CorpCompileData:
|
134
|
-
"""
|
135
|
-
기업 데이터를 컴파일하여 저장하는 데이터 클래스.
|
136
|
-
|
137
|
-
속성:
|
138
|
-
code (str): 기업 코드.
|
139
|
-
name (str): 기업 이름.
|
140
|
-
red_data (eval.RedData): RED 분석 데이터.
|
141
|
-
mil_data (eval.MilData): MIL 분석 데이터.
|
142
|
-
prophet_data (tsa.ProphetData): Prophet 예측 데이터.
|
143
|
-
lstm_grade (tsa.LSTMGrade): LSTM 등급 데이터.
|
144
|
-
is_lstm_up (bool): LSTM 상승 여부.
|
145
|
-
is_prophet_up (bool): Prophet 상승 여부.
|
146
|
-
lstm_html (str): LSTM 시각화 HTML.
|
147
|
-
prophet_html (str): Prophet 시각화 HTML.
|
148
|
-
"""
|
149
|
-
code: str
|
150
|
-
name: str
|
151
|
-
|
152
|
-
red_data: eval.RedData
|
153
|
-
mil_data: eval.MilData
|
154
|
-
|
155
|
-
prophet_data: tsa.ProphetData
|
156
|
-
lstm_grade: tsa.LSTMGrade
|
157
|
-
|
158
|
-
is_lstm_up: bool = False
|
159
|
-
is_prophet_up: bool = False
|
160
|
-
|
161
|
-
lstm_html: str = ''
|
162
|
-
prophet_html: str = ''
|
163
|
-
|
164
|
-
|
165
|
-
class CorpCompile:
|
166
|
-
"""
|
167
|
-
기업 데이터를 컴파일하는 클래스.
|
168
|
-
|
169
|
-
메서드:
|
170
|
-
get(refresh=False) -> CorpCompileData:
|
171
|
-
기업 데이터를 컴파일하거나 캐시에서 가져옵니다.
|
172
|
-
|
173
|
-
red_ranking(expect_earn: float = 0.06, refresh=False) -> OrderedDict:
|
174
|
-
RED 데이터를 기반으로 기업 순위를 계산합니다.
|
175
|
-
|
176
|
-
prophet_ranking(refresh=False, top: Union[int, str]='all') -> OrderedDict:
|
177
|
-
Prophet 데이터를 기반으로 기업 순위를 계산합니다.
|
178
|
-
|
179
|
-
analyse_lstm_topn(refresh: bool, top=40):
|
180
|
-
상위 N개의 기업에 대해 LSTM 예측 수행.
|
181
|
-
"""
|
182
|
-
def __init__(self, code: str, expect_earn=0.06):
|
183
|
-
"""
|
184
|
-
CorpCompile 객체를 초기화합니다.
|
185
|
-
|
186
|
-
매개변수:
|
187
|
-
code (str): 기업 코드.
|
188
|
-
expect_earn (float, optional): 예상 수익률. 기본값은 0.06.
|
189
|
-
"""
|
190
|
-
assert tools.is_6digit(code), f'Invalid value : {code}'
|
191
|
-
self._code = code
|
192
|
-
self.expect_earn = expect_earn
|
193
|
-
|
194
|
-
@property
|
195
|
-
def code(self) -> str:
|
196
|
-
"""
|
197
|
-
기업 코드를 반환합니다.
|
198
|
-
|
199
|
-
반환값:
|
200
|
-
str: 기업 코드.
|
201
|
-
"""
|
202
|
-
return self._code
|
203
|
-
|
204
|
-
@code.setter
|
205
|
-
def code(self, code: str):
|
206
|
-
"""
|
207
|
-
기업 코드를 변경합니다.
|
208
|
-
|
209
|
-
매개변수:
|
210
|
-
code (str): 새로 설정할 기업 코드.
|
211
|
-
"""
|
212
|
-
assert tools.is_6digit(code), f'Invalid value : {code}'
|
213
|
-
mylogger.info(f'change code : {self.code} -> {code}')
|
214
|
-
self._code = code
|
215
|
-
|
216
|
-
def get(self, refresh=False) -> CorpCompileData:
|
217
|
-
"""
|
218
|
-
기업 데이터를 컴파일하여 캐시에 저장하거나 캐시에서 가져옵니다.
|
219
|
-
|
220
|
-
매개변수:
|
221
|
-
refresh (bool): 데이터를 새로 가져올지 여부.
|
222
|
-
|
223
|
-
반환값:
|
224
|
-
CorpCompileData: 컴파일된 기업 데이터.
|
225
|
-
"""
|
226
|
-
print(f"{self.code}의 compiling을 시작합니다.")
|
227
|
-
redis_name = self.code + '_corp_compile'
|
228
|
-
print(
|
229
|
-
f"redisname: '{redis_name}' / refresh : {refresh} / expire_time : {expire_time/3600}h")
|
230
|
-
|
231
|
-
def fetch_corp_compile_data() -> CorpCompileData:
|
232
|
-
prophet = tsa.CorpProphet(self.code)
|
233
|
-
lstm = tsa.CorpLSTM(self.code)
|
234
|
-
|
235
|
-
data = CorpCompileData(
|
236
|
-
code=self.code,
|
237
|
-
name=myredis.Corps(self.code,'c101').get_name(data_from='mongo'),
|
238
|
-
red_data=eval.Red(self.code, self.expect_earn).get(refresh=refresh, verbose=False),
|
239
|
-
mil_data=eval.Mil(self.code).get(refresh=refresh, verbose=False),
|
240
|
-
prophet_data=prophet.generate_data(refresh=refresh),
|
241
|
-
lstm_grade=lstm.get_final_predictions(refresh=refresh)[1],
|
242
|
-
)
|
243
|
-
|
244
|
-
data.is_lstm_up = lstm.is_lstm_up()
|
245
|
-
data.is_prophet_up = prophet.is_prophet_up(refresh=False)
|
246
|
-
data.lstm_html = lstm.export(refresh=False)
|
247
|
-
data.prophet_html = prophet.export()
|
248
|
-
return data
|
249
|
-
|
250
|
-
corp_compile_data = myredis.Base.fetch_and_cache_data(redis_name, refresh, fetch_corp_compile_data, timer=expire_time)
|
251
|
-
return corp_compile_data
|
252
|
-
|
253
|
-
@staticmethod
|
254
|
-
def red_ranking(expect_earn: float = 0.06, refresh=False) -> OrderedDict:
|
255
|
-
"""
|
256
|
-
RED 데이터를 기반으로 기업 순위를 계산합니다.
|
257
|
-
|
258
|
-
매개변수:
|
259
|
-
expect_earn (float, optional): 예상 수익률. 기본값은 0.06.
|
260
|
-
refresh (bool): 데이터를 새로 가져올지 여부.
|
261
|
-
|
262
|
-
반환값:
|
263
|
-
OrderedDict: RED 점수를 기준으로 정렬된 기업 순위.
|
264
|
-
"""
|
265
|
-
redis_name = 'red_ranking_prev_expect_earn'
|
266
|
-
pee = tools.to_float(myredis.Base.get_value(redis_name))
|
267
|
-
if pee != expect_earn:
|
268
|
-
mylogger.warning(
|
269
|
-
f"expect earn : {expect_earn} / prev expect earn : {pee} 두 값이 달라 refresh = True"
|
270
|
-
)
|
271
|
-
myredis.Base.set_value(redis_name, str(expect_earn))
|
272
|
-
refresh = True
|
273
|
-
|
274
|
-
print("**** Start red_ranking... ****")
|
275
|
-
redis_name = 'red_ranking'
|
276
|
-
print(
|
277
|
-
f"redisname: '{redis_name}' / expect_earn: {expect_earn} / refresh : {refresh} / expire_time : {expire_time / 3600}h")
|
278
|
-
|
279
|
-
def fetch_ranking(refresh_in: bool) -> dict:
|
280
|
-
data = {}
|
281
|
-
red = eval.Red(code='005930', expect_earn=expect_earn)
|
282
|
-
for i, code in enumerate(myredis.Corps.list_all_codes()):
|
283
|
-
red.code = code
|
284
|
-
red_score = red.get(refresh=refresh_in, verbose=False).score
|
285
|
-
if red_score > 0:
|
286
|
-
data[code] = red_score
|
287
|
-
print(f"{i}: {red} - {red_score}")
|
288
|
-
return data
|
289
|
-
|
290
|
-
data_dict = myredis.Base.fetch_and_cache_data(redis_name, refresh, fetch_ranking, refresh, timer=expire_time)
|
291
|
-
|
292
|
-
return OrderedDict(sorted(data_dict.items(), key=lambda item: item[1], reverse=True))
|
293
|
-
|
294
|
-
@staticmethod
|
295
|
-
def prophet_ranking(refresh=False, top: Union[int, str]='all') -> OrderedDict:
|
296
|
-
"""
|
297
|
-
Prophet 데이터를 기반으로 기업 순위를 계산합니다.
|
298
|
-
|
299
|
-
매개변수:
|
300
|
-
refresh (bool): 데이터를 새로 가져올지 여부.
|
301
|
-
top (Union[int, str], optional): 상위 기업 개수. 'all'이면 전체 반환. 기본값은 'all'.
|
302
|
-
|
303
|
-
반환값:
|
304
|
-
OrderedDict: Prophet 점수를 기준으로 정렬된 기업 순위.
|
305
|
-
"""
|
306
|
-
print("**** Start Compiling scores and sorting... ****")
|
307
|
-
redis_name = 'prophet_ranking'
|
308
|
-
|
309
|
-
print(
|
310
|
-
f"redisname: '{redis_name}' / refresh : {refresh} / expire_time : {expire_time/3600}h")
|
311
|
-
|
312
|
-
def fetch_prophet_ranking() -> dict:
|
313
|
-
data = {}
|
314
|
-
c = tsa.CorpProphet('005930')
|
315
|
-
for code in myredis.Corps.list_all_codes():
|
316
|
-
try:
|
317
|
-
c.code = code
|
318
|
-
except ValueError:
|
319
|
-
mylogger.error(f'prophet ranking error : {code}')
|
320
|
-
continue
|
321
|
-
score= c.generate_data(refresh=refresh).score
|
322
|
-
print(f'{code} compiled : {score}')
|
323
|
-
data[code] = score
|
324
|
-
return data
|
325
|
-
|
326
|
-
data_dict = myredis.Base.fetch_and_cache_data(redis_name, refresh, fetch_prophet_ranking, timer=expire_time)
|
327
|
-
|
328
|
-
ranking = OrderedDict(sorted(data_dict.items(), key=lambda x: x[1], reverse=True))
|
329
|
-
|
330
|
-
if top == 'all':
|
331
|
-
return ranking
|
332
|
-
else:
|
333
|
-
if isinstance(top, int):
|
334
|
-
return OrderedDict(list(ranking.items())[:top])
|
335
|
-
else:
|
336
|
-
raise ValueError("top 인자는 'all' 이나 int형 이어야 합니다.")
|
337
|
-
|
338
|
-
@staticmethod
|
339
|
-
def caching_corp_compile_topn(refresh: bool, top=40):
|
340
|
-
"""
|
341
|
-
상위 N개의 기업에 대해 CorpCompileData를 수집합니다..
|
342
|
-
|
343
|
-
매개변수:
|
344
|
-
refresh (bool): 데이터를 새로 가져올지 여부.
|
345
|
-
top (int, optional): 상위 기업 개수. 기본값은 40.
|
346
|
-
"""
|
347
|
-
ranking_topn = CorpCompile.prophet_ranking(refresh=False, top=top)
|
348
|
-
mylogger.info(ranking_topn)
|
349
|
-
corp_compile = CorpCompile('005930')
|
350
|
-
print(f"*** CorpCompile redis cashing top{top} items ***")
|
351
|
-
for i, (code, _) in enumerate(ranking_topn.items()):
|
352
|
-
corp_compile.code = code
|
353
|
-
print(f"{i + 1}. {code}")
|
354
|
-
corp_compile_data = corp_compile.get(refresh=refresh)
|
355
|
-
print(corp_compile_data)
|
File without changes
|
@@ -1,50 +0,0 @@
|
|
1
|
-
import yfinance as yf
|
2
|
-
import numpy as np
|
3
|
-
import pandas as pd
|
4
|
-
from sklearn.model_selection import train_test_split
|
5
|
-
from sklearn.linear_model import LinearRegression
|
6
|
-
import matplotlib.pyplot as plt
|
7
|
-
|
8
|
-
# 1. 데이터 다운로드 (애플 주식 데이터를 사용)
|
9
|
-
# 데이터 기간: 2020년 1월 1일부터 2023년 1월 1일까지
|
10
|
-
#stock_data = yf.download('AAPL', start='2020-01-01', end='2023-01-01')
|
11
|
-
# 삼성전자 주식 데이터 가져오기 (KOSPI 상장)
|
12
|
-
stock_data = yf.download('005930.KS', start='2020-01-01', end='2024-08-01')
|
13
|
-
# 크래프톤 주식 데이터 가져오기 (KOSPI 상장)
|
14
|
-
#stock_data = yf.download('259960.KS', start='2020-01-01', end='2024-10-08')
|
15
|
-
|
16
|
-
# 2. 필요한 열만 선택 (종가만 사용)
|
17
|
-
df = stock_data[['Close']]
|
18
|
-
|
19
|
-
# 3. 주가 데이터를 시계열 데이터로 변환하여 예측
|
20
|
-
# 일자를 숫자로 변환 (날짜 자체는 예측 모델에 사용하기 어렵기 때문에 숫자로 변환)
|
21
|
-
df['Date'] = np.arange(len(df))
|
22
|
-
|
23
|
-
# 4. 독립 변수(X)와 종속 변수(y) 분리
|
24
|
-
X = df[['Date']] # 독립 변수 (날짜)
|
25
|
-
y = df['Close'] # 종속 변수 (주가)
|
26
|
-
|
27
|
-
# 5. 데이터를 학습용과 테스트용으로 분리 (80% 학습, 20% 테스트)
|
28
|
-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
29
|
-
|
30
|
-
# 6. 선형 회귀 모델 생성 및 학습
|
31
|
-
model = LinearRegression()
|
32
|
-
model.fit(X_train, y_train)
|
33
|
-
|
34
|
-
# 7. 테스트 데이터를 사용하여 예측 수행
|
35
|
-
y_pred = model.predict(X_test)
|
36
|
-
|
37
|
-
# 8. 결과 시각화
|
38
|
-
plt.figure(figsize=(10, 6))
|
39
|
-
plt.scatter(X_train, y_train, color='blue', label='Training data') # 학습 데이터
|
40
|
-
plt.scatter(X_test, y_test, color='green', label='Test data') # 실제 테스트 데이터
|
41
|
-
plt.plot(X_test, y_pred, color='red', label='Predicted price') # 예측된 주가
|
42
|
-
plt.xlabel('Date (numeric)')
|
43
|
-
plt.ylabel('Stock Price (Close)')
|
44
|
-
plt.legend()
|
45
|
-
plt.title('Apple Stock Price Prediction')
|
46
|
-
plt.show()
|
47
|
-
|
48
|
-
# 9. 모델 평가 (R^2 스코어)
|
49
|
-
r2_score = model.score(X_test, y_test)
|
50
|
-
print(f"모델의 R^2 스코어: {r2_score:.2f}")
|
@@ -1,39 +0,0 @@
|
|
1
|
-
# 필요한 라이브러리 불러오기
|
2
|
-
import numpy as np
|
3
|
-
from sklearn.linear_model import LinearRegression
|
4
|
-
from sklearn.model_selection import train_test_split
|
5
|
-
import matplotlib.pyplot as plt
|
6
|
-
|
7
|
-
# 1. 데이터 준비 (주택 면적, 가격)
|
8
|
-
# 예를 들어 면적에 따른 주택 가격 데이터 (면적: X, 가격: y)
|
9
|
-
X = np.array([[1500], [2000], [2500], [3000], [3500], [4000]]) # 면적 (단위: square feet)
|
10
|
-
y = np.array([300000, 400000, 500000, 600000, 700000, 800000]) # 가격 (단위: dollars)
|
11
|
-
|
12
|
-
# 2. 학습 데이터와 테스트 데이터를 나누기 (80% 학습, 20% 테스트)
|
13
|
-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
14
|
-
|
15
|
-
# 3. 선형 회귀 모델 생성
|
16
|
-
model = LinearRegression()
|
17
|
-
|
18
|
-
# 4. 모델을 학습시키기 (train 데이터를 사용)
|
19
|
-
model.fit(X_train, y_train)
|
20
|
-
|
21
|
-
# 5. 테스트 데이터로 예측 수행
|
22
|
-
y_pred = model.predict(X_test)
|
23
|
-
|
24
|
-
# 6. 예측 결과 출력
|
25
|
-
print("실제 값:", y_test)
|
26
|
-
print("예측 값:", y_pred)
|
27
|
-
|
28
|
-
# 7. 시각화를 통해 학습 결과 확인
|
29
|
-
plt.scatter(X_train, y_train, color='blue', label='Training data') # 학습 데이터
|
30
|
-
plt.scatter(X_test, y_test, color='green', label='Test data') # 실제 값
|
31
|
-
plt.plot(X_test, y_pred, color='red', label='Prediction') # 예측된 값
|
32
|
-
plt.xlabel('House Size (square feet)')
|
33
|
-
plt.ylabel('Price (dollars)')
|
34
|
-
plt.legend()
|
35
|
-
plt.show()
|
36
|
-
|
37
|
-
# 9. 모델 평가 (R^2 스코어)
|
38
|
-
r2_score = model.score(X_test, y_test)
|
39
|
-
print(f"모델의 R^2 스코어: {r2_score:.2f}")
|