pro-craft 0.2.57__py3-none-any.whl → 0.2.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pro-craft might be problematic. Click here for more details.
- pro_craft/__init__.py +2 -2
- pro_craft/code_helper/__init__.py +0 -0
- pro_craft/code_helper/agent.py +90 -0
- pro_craft/code_helper/codermanager.py +143 -0
- pro_craft/code_helper/database.py +36 -0
- pro_craft/code_helper/paper_program.py +183 -0
- pro_craft/code_helper/template_extract.py +134 -0
- pro_craft/code_helper/tools.py +113 -0
- pro_craft/code_helper/vectorstore.py +81 -0
- pro_craft/code_helper/write_code.py +61 -0
- pro_craft/database.py +5 -6
- pro_craft/log.py +4 -3
- pro_craft/prompt_craft/async_.py +188 -561
- pro_craft/utils.py +1 -1
- {pro_craft-0.2.57.dist-info → pro_craft-0.2.58.dist-info}/METADATA +7 -1
- pro_craft-0.2.58.dist-info/RECORD +28 -0
- pro_craft/code_helper/coder.py +0 -660
- pro_craft/code_helper/designer.py +0 -115
- pro_craft-0.2.57.dist-info/RECORD +0 -21
- {pro_craft-0.2.57.dist-info → pro_craft-0.2.58.dist-info}/WHEEL +0 -0
- {pro_craft-0.2.57.dist-info → pro_craft-0.2.58.dist-info}/top_level.txt +0 -0
pro_craft/prompt_craft/async_.py
CHANGED
|
@@ -1,27 +1,22 @@
|
|
|
1
1
|
# 测试1
|
|
2
2
|
from pro_craft.utils import extract_
|
|
3
|
-
from pro_craft import logger as pro_craft_logger
|
|
4
3
|
from modusched.core import BianXieAdapter, ArkAdapter
|
|
5
|
-
from datetime import datetime
|
|
6
4
|
from enum import Enum
|
|
7
5
|
import functools
|
|
8
6
|
import json
|
|
9
7
|
import os
|
|
10
|
-
from pro_craft.database import Prompt, UseCase, PromptBase
|
|
8
|
+
from pro_craft.database import Prompt, UseCase, PromptBase, SyncMetadata
|
|
11
9
|
from pro_craft.utils import create_session, create_async_session
|
|
12
10
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine # 异步核心
|
|
13
11
|
from sqlalchemy import select, delete # 导入 select, delete 用于异步操作
|
|
14
12
|
import inspect
|
|
15
|
-
from datetime import datetime
|
|
16
13
|
from pro_craft.utils import extract_
|
|
17
14
|
import asyncio
|
|
18
15
|
import re
|
|
19
16
|
from pydantic import BaseModel, ValidationError, field_validator
|
|
20
17
|
from sqlalchemy import select, desc
|
|
21
18
|
from json.decoder import JSONDecodeError
|
|
22
|
-
|
|
23
|
-
from datetime import datetime, timedelta
|
|
24
|
-
from datetime import datetime, timedelta
|
|
19
|
+
|
|
25
20
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
26
21
|
from sqlalchemy import select, and_ # 引入 select 和 and_
|
|
27
22
|
from sqlalchemy.orm import class_mapper # 用于检查对象是否是持久化的
|
|
@@ -31,6 +26,11 @@ from tqdm.asyncio import tqdm
|
|
|
31
26
|
import pandas as pd
|
|
32
27
|
import plotly.graph_objects as go
|
|
33
28
|
|
|
29
|
+
def get_log_info(target: str, val = None):
|
|
30
|
+
return f"{target} & {type(val)} & {val}"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
|
|
34
34
|
BATCH_SIZE = int(os.getenv("DATABASE_SYNC_BATCH_SIZE",1000))
|
|
35
35
|
|
|
36
36
|
def fix_broken_json_string(broken_json_str):
|
|
@@ -112,7 +112,7 @@ class AsyncIntel():
|
|
|
112
112
|
logger = None,
|
|
113
113
|
):
|
|
114
114
|
database_url = database_url or os.getenv("database_url")
|
|
115
|
-
self.logger = logger
|
|
115
|
+
self.logger = logger
|
|
116
116
|
try:
|
|
117
117
|
assert database_url
|
|
118
118
|
assert 'aio' in database_url
|
|
@@ -156,29 +156,6 @@ class AsyncIntel():
|
|
|
156
156
|
async with engine.begin() as conn:
|
|
157
157
|
await conn.run_sync(PromptBase.metadata.create_all)
|
|
158
158
|
|
|
159
|
-
async def get_prompt(self,prompt_id,version,session):
|
|
160
|
-
"""
|
|
161
|
-
获取指定 prompt_id 的最新版本数据,通过创建时间判断。
|
|
162
|
-
"""
|
|
163
|
-
if version:
|
|
164
|
-
stmt_ = select(Prompt).filter(
|
|
165
|
-
Prompt.prompt_id == prompt_id,
|
|
166
|
-
Prompt.version == version
|
|
167
|
-
)
|
|
168
|
-
else:
|
|
169
|
-
stmt_ = select(Prompt).filter(
|
|
170
|
-
Prompt.prompt_id == prompt_id,
|
|
171
|
-
)
|
|
172
|
-
stmt = stmt_.order_by(
|
|
173
|
-
desc(Prompt.timestamp), # 使用 sqlalchemy.desc() 来指定降序
|
|
174
|
-
desc(Prompt.version) # 使用 sqlalchemy.desc() 来指定降序
|
|
175
|
-
)
|
|
176
|
-
|
|
177
|
-
result = await session.execute(stmt)
|
|
178
|
-
result = result.scalars().first()
|
|
179
|
-
|
|
180
|
-
return result
|
|
181
|
-
|
|
182
159
|
async def sync_production_database(self,database_url:str):
|
|
183
160
|
target_engine = create_async_engine(database_url, echo=False)
|
|
184
161
|
await self.create_database(target_engine)
|
|
@@ -272,7 +249,29 @@ class AsyncIntel():
|
|
|
272
249
|
else:
|
|
273
250
|
print("No new records to sync.")
|
|
274
251
|
|
|
252
|
+
async def get_prompt(self,prompt_id,version,session):
|
|
253
|
+
"""
|
|
254
|
+
获取指定 prompt_id 的最新版本数据,通过创建时间判断。
|
|
255
|
+
"""
|
|
256
|
+
if version:
|
|
257
|
+
stmt_ = select(Prompt).filter(
|
|
258
|
+
Prompt.prompt_id == prompt_id,
|
|
259
|
+
Prompt.version == version
|
|
260
|
+
)
|
|
261
|
+
else:
|
|
262
|
+
stmt_ = select(Prompt).filter(
|
|
263
|
+
Prompt.prompt_id == prompt_id,
|
|
264
|
+
)
|
|
265
|
+
stmt = stmt_.order_by(
|
|
266
|
+
desc(Prompt.timestamp), # 使用 sqlalchemy.desc() 来指定降序
|
|
267
|
+
desc(Prompt.version) # 使用 sqlalchemy.desc() 来指定降序
|
|
268
|
+
)
|
|
275
269
|
|
|
270
|
+
result = await session.execute(stmt)
|
|
271
|
+
result = result.scalars().first()
|
|
272
|
+
|
|
273
|
+
return result
|
|
274
|
+
|
|
276
275
|
async def get_prompt_safe(self,
|
|
277
276
|
prompt_id: str,
|
|
278
277
|
version = None,
|
|
@@ -283,11 +282,13 @@ class AsyncIntel():
|
|
|
283
282
|
prompt_obj = await self.get_prompt(prompt_id=prompt_id,version=version,session=session)
|
|
284
283
|
if prompt_obj:
|
|
285
284
|
return prompt_obj
|
|
285
|
+
if version:
|
|
286
|
+
prompt_obj = await self.get_prompt(prompt_id=prompt_id,version=None,session=session)
|
|
286
287
|
|
|
287
|
-
prompt_obj
|
|
288
|
+
if prompt_obj is None:
|
|
289
|
+
raise IntellectRemoveError("不存在的prompt_id")
|
|
288
290
|
return prompt_obj
|
|
289
291
|
|
|
290
|
-
|
|
291
292
|
async def save_prompt(self,
|
|
292
293
|
prompt_id: str,
|
|
293
294
|
new_prompt: str,
|
|
@@ -328,375 +329,16 @@ class AsyncIntel():
|
|
|
328
329
|
session.add(prompt1)
|
|
329
330
|
await session.commit() # 提交事务,将数据写入数据库
|
|
330
331
|
|
|
331
|
-
async def
|
|
332
|
-
target_prompt_id: str,
|
|
333
|
-
start_time: datetime = None, # 新增:开始时间
|
|
334
|
-
end_time: datetime = None, # 新增:结束时间
|
|
335
|
-
session = None
|
|
336
|
-
):
|
|
337
|
-
"""
|
|
338
|
-
从sql保存提示词
|
|
339
|
-
"""
|
|
340
|
-
stmt = select(UseCase).filter(UseCase.is_deleted == 0,
|
|
341
|
-
UseCase.prompt_id == target_prompt_id)
|
|
342
|
-
|
|
343
|
-
if start_time:
|
|
344
|
-
stmt = stmt.filter(UseCase.created_at >= start_time) # 假设你的UseCase模型有一个created_at字段
|
|
345
|
-
|
|
346
|
-
if end_time:
|
|
347
|
-
stmt = stmt.filter(UseCase.created_at <= end_time)
|
|
348
|
-
result = await session.execute(stmt)
|
|
349
|
-
# use_case = result.scalars().one_or_none()
|
|
350
|
-
use_case = result.scalars().all()
|
|
351
|
-
return use_case
|
|
352
|
-
|
|
353
|
-
async def save_use_case2(self,session = None):
|
|
354
|
-
with open("/Users/zhaoxuefeng/GitHub/digital_life/logs/app.log",'r') as f:
|
|
355
|
-
x = f.read()
|
|
356
|
-
|
|
357
|
-
def work(resu):
|
|
358
|
-
if len(resu) == 9:
|
|
359
|
-
unix_timestamp_str = resu[2]
|
|
360
|
-
dt_object = datetime.datetime.fromtimestamp(float(unix_timestamp_str.strip()))
|
|
361
|
-
use_case = UseCase(
|
|
362
|
-
time = resu[2],
|
|
363
|
-
level = resu[1],
|
|
364
|
-
timestamp =dt_object.strftime('%Y-%m-%d %H:%M:%S.%f'),
|
|
365
|
-
filepath=resu[3],
|
|
366
|
-
function=resu[4],
|
|
367
|
-
lines=resu[5],
|
|
368
|
-
type_=resu[6],
|
|
369
|
-
target=resu[7],
|
|
370
|
-
content=resu[8],
|
|
371
|
-
)
|
|
372
|
-
session.add(use_case)
|
|
373
|
-
else:
|
|
374
|
-
print(len(resu))
|
|
375
|
-
print(resu,'resu')
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
for res in x.split("||"):
|
|
379
|
-
resu = res.split("$")
|
|
380
|
-
work(resu)
|
|
381
|
-
|
|
382
|
-
await session.commit() # 提交事务,将数据写入数据库
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
async def save_use_case3(self, session):
|
|
386
|
-
log_filepath = "/Users/zhaoxuefeng/GitHub/digital_life/logs/app.log"
|
|
387
|
-
|
|
388
|
-
# 1. 获取数据库中已有的最新时间戳
|
|
389
|
-
# 假设 timestamp 列是 DATETIME 类型,且是用于唯一标识和排序的关键字段
|
|
390
|
-
latest_db_timestamp = None
|
|
391
|
-
try:
|
|
392
|
-
result = await session.execute(
|
|
393
|
-
select(UseCase.timestamp)
|
|
394
|
-
.order_by(UseCase.timestamp.desc())
|
|
395
|
-
.limit(1)
|
|
396
|
-
)
|
|
397
|
-
latest_db_timestamp = result.scalar_one_or_none()
|
|
398
|
-
if latest_db_timestamp:
|
|
399
|
-
print(f"Latest timestamp in DB: {latest_db_timestamp}")
|
|
400
|
-
else:
|
|
401
|
-
print("No records found in DB. Starting fresh.")
|
|
402
|
-
except Exception as e:
|
|
403
|
-
print(f"Error querying latest timestamp: {e}")
|
|
404
|
-
# 如果查询失败,可以选择继续,但可能导致重复导入,或者直接退出
|
|
405
|
-
return
|
|
406
|
-
|
|
407
|
-
added_count = 0
|
|
408
|
-
skipped_count = 0
|
|
409
|
-
error_count = 0
|
|
410
|
-
|
|
411
|
-
# 2. 读取并处理日志文件
|
|
412
|
-
try:
|
|
413
|
-
with open(log_filepath, 'r') as f:
|
|
414
|
-
x = f.read()
|
|
415
|
-
except FileNotFoundError:
|
|
416
|
-
print(f"Error: Log file not found at {log_filepath}")
|
|
417
|
-
return
|
|
418
|
-
|
|
419
|
-
# 日志记录通常是逐行添加的,所以倒序处理可能更高效,但也取决于文件大小和格式
|
|
420
|
-
# 对于你当前的分隔符格式,还是顺序处理比较直接
|
|
421
|
-
for res_str in x.split("||"):
|
|
422
|
-
if not res_str.strip(): # 跳过空字符串
|
|
423
|
-
continue
|
|
424
|
-
|
|
425
|
-
# 使用 try-except 块来处理可能的解析错误
|
|
426
|
-
try:
|
|
427
|
-
resu = res_str.split("$")
|
|
428
|
-
# 检查字段数量是否正确
|
|
429
|
-
# 你的原始代码期望 len(resu) == 9, 但是 SQL 语句有 10 个字段,
|
|
430
|
-
# UseCase 构造函数也有 9 个参数,这需要对应起来
|
|
431
|
-
# level, time, timestamp, filepath, function, lines, type_, target, content, is_deleted
|
|
432
|
-
# 对应 resu 的 index: [1], [2], [2], [3], [4], [5], [6], [7], [8]
|
|
433
|
-
# 看起来 resu[0] 是空的或者不用的,所以 resu 至少需要有 9 个元素(索引0-8)
|
|
434
|
-
# 因此,判断 len(resu) >= 9 即可
|
|
435
|
-
if len(resu) < 9:
|
|
436
|
-
print(f"Skipping malformed log entry (not enough fields): {res_str}")
|
|
437
|
-
error_count += 1
|
|
438
|
-
continue
|
|
439
|
-
|
|
440
|
-
# 提取并清理原始数据
|
|
441
|
-
level_raw = resu[0].strip() # 假设 resu[0] 是 level
|
|
442
|
-
time_raw = resu[1].strip() # 假设 resu[1] 是 time (原始日志时间字符串)
|
|
443
|
-
timestamp_raw = resu[2].strip() # 假设 resu[2] 是 timestamp (unix时间戳字符串)
|
|
444
|
-
filepath_raw = resu[3].strip()
|
|
445
|
-
function_raw = resu[4].strip()
|
|
446
|
-
lines_raw = resu[5].strip()
|
|
447
|
-
type_raw = resu[6].strip()
|
|
448
|
-
target_raw = resu[7].strip()
|
|
449
|
-
content_raw = resu[8].strip()
|
|
450
|
-
|
|
451
|
-
# 处理 time 字段 (原始日志时间字符串)
|
|
452
|
-
# 上次错误是 time 列太短,并且有换行符。在这里清理一下。
|
|
453
|
-
# 假设 time 字段就是你日志中 YYYY-MM-DD HH:MM:SS,ms 这种格式,
|
|
454
|
-
# 但你上次给的示例是 '\n2025-11-01 11:17:52,029 ',所以需要清理
|
|
455
|
-
processed_time_str = time_raw.replace('\n', '').strip()
|
|
456
|
-
# 如果数据库 time 列是 VARCHAR,确保长度够用
|
|
457
|
-
# 如果是 DATETIME,你需要解析它
|
|
458
|
-
# 例如:
|
|
459
|
-
# try:
|
|
460
|
-
# dt_obj_from_time = datetime.datetime.strptime(processed_time_str, '%Y-%m-%d %H:%M:%S,%f')
|
|
461
|
-
# # 再次格式化,确保数据库兼容
|
|
462
|
-
# processed_time_str = dt_obj_from_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
|
|
463
|
-
# except ValueError:
|
|
464
|
-
# print(f"Warning: Could not parse 'time' string: {processed_time_str}. Using raw string.")
|
|
465
|
-
# # 或者跳过此条记录,或者设置为 None
|
|
466
|
-
|
|
467
|
-
# 处理 timestamp 字段 (Unix时间戳转换为 datetime 对象)
|
|
468
|
-
try:
|
|
469
|
-
unix_timestamp_float = float(timestamp_raw)
|
|
470
|
-
dt_object = datetime.datetime.fromtimestamp(unix_timestamp_float)
|
|
471
|
-
# 格式化为 MySQL DATETIME(6) 兼容的字符串(包含微秒)
|
|
472
|
-
formatted_timestamp = dt_object.strftime('%Y-%m-%d %H:%M:%S.%f')
|
|
473
|
-
except ValueError:
|
|
474
|
-
print(f"Skipping malformed log entry (invalid timestamp float): {res_str}")
|
|
475
|
-
error_count += 1
|
|
476
|
-
continue
|
|
477
|
-
|
|
478
|
-
# 3. 比较时间戳进行增量检查
|
|
479
|
-
# 将格式化后的时间戳字符串转换为 datetime 对象进行比较
|
|
480
|
-
current_log_timestamp = datetime.datetime.strptime(formatted_timestamp, '%Y-%m-%d %H:%M:%S.%f')
|
|
481
|
-
|
|
482
|
-
if latest_db_timestamp and current_log_timestamp <= latest_db_timestamp:
|
|
483
|
-
# print(f"Skipping existing log entry (timestamp: {current_log_timestamp})")
|
|
484
|
-
skipped_count += 1
|
|
485
|
-
continue # 跳过已存在的或旧的记录
|
|
486
|
-
|
|
487
|
-
# 创建 UseCase 实例
|
|
488
|
-
use_case = UseCase(
|
|
489
|
-
time=processed_time_str, # 使用清理后的原始时间字符串
|
|
490
|
-
level=level_raw,
|
|
491
|
-
timestamp=current_log_timestamp, # 传入 datetime 对象
|
|
492
|
-
filepath=filepath_raw,
|
|
493
|
-
function=function_raw,
|
|
494
|
-
lines=lines_raw,
|
|
495
|
-
type_=type_raw,
|
|
496
|
-
target=target_raw,
|
|
497
|
-
content=content_raw,
|
|
498
|
-
is_deleted=False, # 默认值
|
|
499
|
-
)
|
|
500
|
-
session.add(use_case)
|
|
501
|
-
added_count += 1
|
|
502
|
-
|
|
503
|
-
except Exception as e:
|
|
504
|
-
print(f"Error processing log entry: {res_str}. Error: {e}")
|
|
505
|
-
error_count += 1
|
|
506
|
-
session.rollback() # 如果在添加过程中发生错误,回滚当前批次,避免污染 session
|
|
507
|
-
# 重新开始一个新的事务,或者处理这个错误
|
|
508
|
-
|
|
509
|
-
# 4. 提交事务
|
|
510
|
-
try:
|
|
511
|
-
await session.commit()
|
|
512
|
-
print(f"Log processing complete: Added {added_count} new entries, skipped {skipped_count} existing entries, encountered {error_count} errors.")
|
|
513
|
-
except Exception as e:
|
|
514
|
-
print(f"Error during final commit: {e}")
|
|
515
|
-
await session.rollback()
|
|
516
|
-
|
|
517
|
-
async def save_use_case(self,
|
|
518
|
-
prompt_id: str,
|
|
519
|
-
use_case:str = "",
|
|
520
|
-
timestamp = "",
|
|
521
|
-
output = "",
|
|
522
|
-
solution: str = "",
|
|
523
|
-
faired_time = 0,
|
|
524
|
-
session = None
|
|
525
|
-
):
|
|
526
|
-
|
|
527
|
-
"""
|
|
528
|
-
从sql保存提示词
|
|
529
|
-
"""
|
|
530
|
-
#TODO 存之前保证数据库中相同的prompt_id中没有重复的use_case
|
|
531
|
-
use_cases = await self.get_use_case(target_prompt_id = prompt_id,
|
|
532
|
-
session = session)
|
|
533
|
-
for use_case_old in use_cases:
|
|
534
|
-
if use_case == use_case_old.use_case:
|
|
535
|
-
# print("用例已经存在")
|
|
536
|
-
return
|
|
537
|
-
#time,level, timestamp filepath, function lines, type_, target, content
|
|
538
|
-
with open("/Users/zhaoxuefeng/GitHub/digital_life/logs/app.log",'r') as f:
|
|
539
|
-
x = f.read()
|
|
540
|
-
resu = x.split("||")[14].split("$")
|
|
541
|
-
|
|
542
|
-
use_case = UseCase(
|
|
543
|
-
time = resu[0],
|
|
544
|
-
level = resu[1],
|
|
545
|
-
timestamp =resu[2],
|
|
546
|
-
filepath=resu[3],
|
|
547
|
-
function=resu[4],
|
|
548
|
-
lines=resu[5],
|
|
549
|
-
type_=resu[6],
|
|
550
|
-
target=resu[7],
|
|
551
|
-
content=resu[8],
|
|
552
|
-
)
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
# use_case = UseCase(prompt_id=prompt_id,
|
|
557
|
-
# use_case = use_case,
|
|
558
|
-
# timestamp = timestamp,
|
|
559
|
-
# output = output,
|
|
560
|
-
# solution = solution,
|
|
561
|
-
# faired_time = faired_time,
|
|
562
|
-
# )
|
|
563
|
-
|
|
564
|
-
session.add(use_case)
|
|
565
|
-
await session.commit() # 提交事务,将数据写入数据库
|
|
566
|
-
|
|
567
|
-
async def push_action_order(self,demand : str,prompt_id: str,
|
|
568
|
-
action_type = 'train'):# init
|
|
569
|
-
|
|
570
|
-
"""
|
|
571
|
-
从sql保存提示词
|
|
572
|
-
推一个train 状态到指定的位置
|
|
332
|
+
async def adjust_prompt(self,prompt_id: str,action_type = "summary", demand: str = ""):
|
|
573
333
|
|
|
574
|
-
将打算修改的状态推上数据库 # 1
|
|
575
|
-
"""
|
|
576
|
-
# 查看是否已经存在
|
|
577
|
-
async with create_async_session(self.engine) as session:
|
|
578
|
-
|
|
579
|
-
latest_prompt = await self.get_prompt_safe(prompt_id=prompt_id,session=session)
|
|
580
|
-
if latest_prompt:
|
|
581
|
-
await self.save_prompt(prompt_id=latest_prompt.prompt_id,
|
|
582
|
-
new_prompt = latest_prompt.prompt,
|
|
583
|
-
use_case = latest_prompt.use_case,
|
|
584
|
-
action_type=action_type,
|
|
585
|
-
demand=demand,
|
|
586
|
-
score=latest_prompt.score,
|
|
587
|
-
session=session
|
|
588
|
-
)
|
|
589
|
-
return "success"
|
|
590
|
-
else:
|
|
591
|
-
await self.save_prompt(prompt_id=prompt_id,
|
|
592
|
-
new_prompt = demand,
|
|
593
|
-
use_case = "",
|
|
594
|
-
action_type="inference",
|
|
595
|
-
demand=demand,
|
|
596
|
-
score=60,
|
|
597
|
-
session=session
|
|
598
|
-
)
|
|
599
|
-
return "init"
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
async def intellect(self,
|
|
603
|
-
input_data: dict | str,
|
|
604
|
-
output_format: str,
|
|
605
|
-
prompt_id: str,
|
|
606
|
-
version: str = None,
|
|
607
|
-
change_case = False,
|
|
608
|
-
):
|
|
609
|
-
"""
|
|
610
|
-
自定自动化执行命令的方法,
|
|
611
|
-
不涉及严格的校验, 主要职能在自动化的修改提示词, 或者管理提示词上
|
|
612
|
-
"""
|
|
613
|
-
if isinstance(input_data,dict):
|
|
614
|
-
input_ = json.dumps(input_data,ensure_ascii=False)
|
|
615
|
-
elif isinstance(input_data,str):
|
|
616
|
-
input_ = input_data
|
|
617
|
-
|
|
618
334
|
# 查数据库, 获取最新提示词对象
|
|
619
335
|
async with create_async_session(self.engine) as session:
|
|
620
336
|
result_obj = await self.get_prompt_safe(prompt_id=prompt_id,session=session)
|
|
621
|
-
if result_obj is None:
|
|
622
|
-
raise IntellectRemoveError("不存在的prompt_id")
|
|
623
337
|
|
|
624
338
|
prompt = result_obj.prompt
|
|
625
|
-
|
|
626
|
-
# 直接推理即可
|
|
627
|
-
ai_result = await self.llm.aproduct(prompt + output_format + "\nuser:" + input_)
|
|
628
|
-
|
|
629
|
-
elif result_obj.action_type == "train":
|
|
630
|
-
assert result_obj.demand # 如果type = train 且 demand 是空 则报错
|
|
631
|
-
# 则训练推广
|
|
632
|
-
|
|
633
|
-
# 新版本 默人修改会 inference 状态
|
|
634
|
-
|
|
339
|
+
use_case = result_obj.use_case
|
|
635
340
|
|
|
636
|
-
|
|
637
|
-
# # 注意, 这里的调整要求使用最初的那个输入, 最好一口气调整好
|
|
638
|
-
# chat_history = prompt
|
|
639
|
-
# if input_ == before_input: # 输入没变, 说明还是针对同一个输入进行讨论
|
|
640
|
-
# # input_prompt = chat_history + "\nuser:" + demand
|
|
641
|
-
# input_prompt = chat_history + "\nuser:" + demand + output_format
|
|
642
|
-
# else:
|
|
643
|
-
# # input_prompt = chat_history + "\nuser:" + demand + "\n-----input----\n" + input_
|
|
644
|
-
# input_prompt = chat_history + "\nuser:" + demand + output_format + "\n-----input----\n" + input_
|
|
645
|
-
|
|
646
|
-
# ai_result = await self.llm.aproduct(input_prompt)
|
|
647
|
-
# chat_history = input_prompt + "\nassistant:\n" + ai_result # 用聊天记录作为完整提示词
|
|
648
|
-
# await self.save_prompt(prompt_id, chat_history,
|
|
649
|
-
# use_case = input_,
|
|
650
|
-
# score = 60,
|
|
651
|
-
# session = session)
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
# version 2
|
|
655
|
-
|
|
656
|
-
# if input_ == before_input:
|
|
657
|
-
# new_prompt = prompt + "\nuser:" + demand
|
|
658
|
-
# else:
|
|
659
|
-
# new_prompt = prompt + "\nuser:" + input_
|
|
660
|
-
|
|
661
|
-
# ai_result = await self.llm.aproduct(new_prompt + output_format)
|
|
662
|
-
|
|
663
|
-
# save_new_prompt = new_prompt + "\nassistant:\n" + ai_result
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
# await self.save_prompt(
|
|
667
|
-
# prompt_id,
|
|
668
|
-
# new_prompt=save_new_prompt,
|
|
669
|
-
# use_case = input_,
|
|
670
|
-
# action_type = "inference",
|
|
671
|
-
# score = 60,
|
|
672
|
-
# session = session)
|
|
673
|
-
chat_history = prompt
|
|
674
|
-
before_input = result_obj.use_case
|
|
675
|
-
demand = result_obj.demand
|
|
676
|
-
input_data = input_
|
|
677
|
-
if before_input == "" or change_case is True:
|
|
678
|
-
result_obj.use_case = input_
|
|
679
|
-
await session.commit()
|
|
680
|
-
# 查询上一条, 将before_input 更新位input_
|
|
681
|
-
prompt += input_
|
|
682
|
-
|
|
683
|
-
# 使用更新后的数据进行后续步骤
|
|
684
|
-
new_prompt = prompt + "\nuser:" + demand
|
|
685
|
-
|
|
686
|
-
ai_result = await self.llm.aproduct(new_prompt + output_format)
|
|
687
|
-
|
|
688
|
-
save_new_prompt = new_prompt + "\nassistant:\n" + ai_result
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
await self.save_prompt(
|
|
692
|
-
prompt_id,
|
|
693
|
-
new_prompt=save_new_prompt,
|
|
694
|
-
use_case = input_,
|
|
695
|
-
action_type = "inference",
|
|
696
|
-
score = 60,
|
|
697
|
-
session = session)
|
|
698
|
-
|
|
699
|
-
elif result_obj.action_type == "summary":
|
|
341
|
+
if action_type == "summary":
|
|
700
342
|
system_prompt_summary = """
|
|
701
343
|
很棒, 我们已经达成了某种默契, 我们之间合作无间, 但是, 可悲的是, 当我关闭这个窗口的时候, 你就会忘记我们之间经历的种种磨合, 这是可惜且心痛的, 所以你能否将目前这一套处理流程结晶成一个优质的prompt 这样, 我们下一次只要将prompt输入, 你就能想起我们今天的磨合过程,
|
|
702
344
|
对了,我提示一点, 这个prompt的主角是你, 也就是说, 你在和未来的你对话, 你要教会未来的你今天这件事, 是否让我看懂到时其次
|
|
@@ -704,24 +346,12 @@ class AsyncIntel():
|
|
|
704
346
|
只要输出提示词内容即可, 不需要任何的说明和解释
|
|
705
347
|
"""
|
|
706
348
|
|
|
707
|
-
latest_prompt = await self.get_prompt_safe(prompt_id=prompt_id,session=session)
|
|
708
|
-
|
|
709
349
|
system_result = await self.llm.aproduct(prompt + system_prompt_summary)
|
|
710
350
|
s_prompt = extract_(system_result,pattern_key=r"prompt")
|
|
711
351
|
new_prompt = s_prompt or system_result
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
use_case = latest_prompt.use_case,
|
|
716
|
-
score = 65,
|
|
717
|
-
action_type = "inference",
|
|
718
|
-
session = session
|
|
719
|
-
)
|
|
720
|
-
|
|
721
|
-
ai_result = await self.llm.aproduct(prompt + output_format + "\nuser:" + input_)
|
|
722
|
-
|
|
723
|
-
elif result_obj.action_type == "finetune":
|
|
724
|
-
demand = result_obj.demand
|
|
352
|
+
|
|
353
|
+
elif action_type == "finetune":
|
|
354
|
+
assert demand
|
|
725
355
|
change_by_opinion_prompt = """
|
|
726
356
|
你是一个资深AI提示词工程师,具备卓越的Prompt设计与优化能力。
|
|
727
357
|
我将为你提供一段现有System Prompt。你的核心任务是基于这段Prompt进行修改,以实现我提出的特定目标和功能需求。
|
|
@@ -748,76 +378,45 @@ class AsyncIntel():
|
|
|
748
378
|
功能需求:
|
|
749
379
|
{opinion}
|
|
750
380
|
"""
|
|
381
|
+
new_prompt = await self.llm.aproduct(
|
|
382
|
+
change_by_opinion_prompt.format(old_system_prompt=prompt, opinion=demand)
|
|
383
|
+
)
|
|
751
384
|
|
|
752
|
-
|
|
753
|
-
prompt_ = await self.get_prompt_safe(prompt_id = prompt_id,version = version,
|
|
754
|
-
session=session)
|
|
385
|
+
elif action_type == "patch":
|
|
755
386
|
assert demand
|
|
756
|
-
|
|
757
|
-
if demand:
|
|
758
|
-
new_prompt = await self.llm.aproduct(
|
|
759
|
-
change_by_opinion_prompt.format(old_system_prompt=prompt_.prompt, opinion=demand)
|
|
760
|
-
)
|
|
761
|
-
else:
|
|
762
|
-
new_prompt = prompt_
|
|
763
|
-
await self.save_prompt(
|
|
764
|
-
prompt_id,
|
|
765
|
-
new_prompt = new_prompt,
|
|
766
|
-
use_case = latest_prompt.use_case,
|
|
767
|
-
score = 70,
|
|
768
|
-
action_type = "inference",
|
|
769
|
-
session = session
|
|
770
|
-
)
|
|
771
|
-
|
|
772
|
-
ai_result = await self.llm.aproduct(prompt + output_format + "\nuser:" + input_)
|
|
387
|
+
new_prompt = prompt + "\n"+demand,
|
|
773
388
|
|
|
774
|
-
elif
|
|
775
|
-
demand = result_obj.demand
|
|
776
|
-
assert demand
|
|
777
|
-
latest_prompt = await self.get_prompt_safe(prompt_id=prompt_id,session=session)
|
|
778
|
-
|
|
779
|
-
chat_history = prompt + demand
|
|
780
|
-
await self.save_prompt(prompt_id,
|
|
781
|
-
chat_history,
|
|
782
|
-
use_case = latest_prompt.use_case,
|
|
783
|
-
score = 70,
|
|
784
|
-
action_type = "inference",
|
|
785
|
-
session = session)
|
|
786
|
-
|
|
787
|
-
ai_result = await self.llm.aproduct(chat_history + output_format + "\nuser:" + input_)
|
|
788
|
-
|
|
789
|
-
elif result_obj.action_type.startswith("to:"):
|
|
389
|
+
elif action_type.startswith("to:"):
|
|
790
390
|
target_version = result_obj.action_type.split(":")[-1]
|
|
791
|
-
latest_prompt = await self.get_prompt_safe(prompt_id=prompt_id,session=session)
|
|
792
391
|
prompt_obj = await self.get_prompt_safe(prompt_id=prompt_id,
|
|
793
392
|
version=target_version,
|
|
794
393
|
session=session)
|
|
795
394
|
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
use_case = latest_prompt.use_case,
|
|
799
|
-
score = prompt_obj.score,
|
|
800
|
-
action_type = "inference",
|
|
801
|
-
session = session)
|
|
802
|
-
ai_result = await self.llm.aproduct(prompt_obj.prompt + output_format + "\nuser:" + input_)
|
|
803
|
-
|
|
804
|
-
elif result_obj.action_type == "pass":
|
|
805
|
-
pass
|
|
806
|
-
|
|
395
|
+
new_prompt = prompt_obj.prompt
|
|
396
|
+
|
|
807
397
|
else:
|
|
808
398
|
raise
|
|
809
399
|
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
400
|
+
await self.save_prompt(
|
|
401
|
+
prompt_id,
|
|
402
|
+
new_prompt = new_prompt,
|
|
403
|
+
use_case = use_case,
|
|
404
|
+
score = 70,
|
|
405
|
+
action_type = "inference",
|
|
406
|
+
session = session
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
return "success"
|
|
410
|
+
|
|
411
|
+
async def inference_format(self,
|
|
813
412
|
input_data: dict | str,
|
|
814
|
-
OutputFormat: object | None,
|
|
815
413
|
prompt_id: str,
|
|
816
|
-
ExtraFormats: list[object] = [],
|
|
817
414
|
version: str = None,
|
|
415
|
+
OutputFormat: object | None = None,
|
|
416
|
+
ExtraFormats: list[object] = [],
|
|
818
417
|
ConTent_Function = None,
|
|
819
418
|
AConTent_Function = None,
|
|
820
|
-
|
|
419
|
+
again = True,
|
|
821
420
|
):
|
|
822
421
|
"""
|
|
823
422
|
这个format 是严格校验模式, 是interllect 的增强版, 会主动校验内容,并及时抛出异常(或者伺机修正)
|
|
@@ -834,134 +433,175 @@ class AsyncIntel():
|
|
|
834
433
|
"```json([\s\S]*?)```"
|
|
835
434
|
使用以下方式验证
|
|
836
435
|
"""
|
|
837
|
-
|
|
838
|
-
output_format = base_format_prompt + "\n".join([inspect.getsource(outputformat) for outputformat in ExtraFormats]) + inspect.getsource(OutputFormat)
|
|
839
|
-
else:
|
|
840
|
-
output_format = ""
|
|
436
|
+
assert isinstance(input_data,(dict,str))
|
|
841
437
|
|
|
842
|
-
if
|
|
843
|
-
|
|
438
|
+
input_ = json.dumps(input_data,ensure_ascii=False) if isinstance(input_data,dict) else input_data
|
|
439
|
+
output_format = base_format_prompt + "\n".join([inspect.getsource(outputformat) for outputformat in ExtraFormats]) + inspect.getsource(OutputFormat) if OutputFormat else ""
|
|
440
|
+
self.logger and self.logger.info(get_log_info("intel-输入",input_data))
|
|
844
441
|
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
442
|
+
async with create_async_session(self.engine) as session:
|
|
443
|
+
result_obj = await self.get_prompt_safe(prompt_id=prompt_id,version= version,
|
|
444
|
+
session=session)
|
|
445
|
+
prompt = result_obj.prompt
|
|
446
|
+
ai_result = await self.llm.aproduct(prompt + output_format + "\nuser:" + input_)
|
|
447
|
+
|
|
448
|
+
def check_json_valid(ai_result,OutputFormat):
|
|
852
449
|
try:
|
|
853
450
|
json_str = extract_(ai_result,r'json')
|
|
854
451
|
ai_result = json.loads(json_str)
|
|
855
452
|
OutputFormat(**ai_result)
|
|
856
453
|
|
|
857
454
|
except JSONDecodeError as e:
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
ai_result = json.loads(json_str)
|
|
862
|
-
OutputFormat(**ai_result)
|
|
863
|
-
|
|
864
|
-
except JSONDecodeError as e:
|
|
865
|
-
self.logger.error(f'{type(json_str)} $ {prompt_id}intellect生成的内容为无法被Json解析 $ {json_str}')
|
|
866
|
-
raise IntellectRemoveFormatError(f"prompt_id: {prompt_id} 生成的内容为无法被Json解析 {e}") from e
|
|
867
|
-
|
|
455
|
+
self.logger.error(f'{type(json_str)} $ {prompt_id}intellect生成的内容为无法被Json解析 $ {json_str}')
|
|
456
|
+
# raise IntellectRemoveFormatError(f"prompt_id: {prompt_id} 生成的内容为无法被Json解析 {e}") from e
|
|
457
|
+
return 0
|
|
868
458
|
except ValidationError as e:
|
|
869
459
|
err_info = e.errors()[0]
|
|
870
|
-
|
|
871
|
-
|
|
460
|
+
self.logger.error(f'{type(json_str)} $ {prompt_id}解析未通过OutputFormat $ {json_str}')
|
|
461
|
+
# raise IntellectRemoveFormatError(f"{err_info["type"]}: 属性:{err_info['loc']}, 发生了如下错误: {err_info['msg']}, 格式校验失败, 当前输入为: {err_info['input']} 请检查") from e
|
|
462
|
+
return 0
|
|
872
463
|
except Exception as e:
|
|
873
464
|
raise Exception(f"Error {prompt_id} : {e}") from e
|
|
465
|
+
return 1
|
|
874
466
|
|
|
875
|
-
if
|
|
467
|
+
if OutputFormat:
|
|
468
|
+
check_result = check_json_valid(ai_result,OutputFormat)
|
|
469
|
+
if check_result ==0 and again:
|
|
470
|
+
ai_result = await self.llm.aproduct(ai_result + output_format)
|
|
471
|
+
check_result_ = check_json_valid(ai_result,OutputFormat)
|
|
472
|
+
if check_result_ ==0:
|
|
473
|
+
raise IntellectRemoveFormatError(f"prompt_id: {prompt_id} 多次生成的内容均未通过OutputFormat校验, 当前内容为: {ai_result}")
|
|
474
|
+
json_str = extract_(ai_result,r'json')
|
|
475
|
+
ai_result = json.loads(json_str)
|
|
476
|
+
|
|
477
|
+
if ConTent_Function:# TODO
|
|
876
478
|
ConTent_Function(ai_result,input_data)
|
|
877
479
|
|
|
878
480
|
if AConTent_Function:
|
|
879
481
|
await AConTent_Function(ai_result,input_data)
|
|
880
482
|
|
|
881
|
-
|
|
882
|
-
logger.info(f'{type(ai_result)} $ intellect输出 ai_result $ {ai_result}')
|
|
483
|
+
self.logger and self.logger.info(f'{type(ai_result)} $ intellect输出 ai_result $ {ai_result}')
|
|
883
484
|
return ai_result
|
|
884
485
|
|
|
885
|
-
async def
|
|
486
|
+
async def inference_format_gather(self,
|
|
886
487
|
input_datas: list[dict | str],
|
|
887
|
-
OutputFormat: object | None,
|
|
888
488
|
prompt_id: str,
|
|
889
|
-
ExtraFormats: list[object] = [],
|
|
890
489
|
version: str = None,
|
|
490
|
+
OutputFormat: object | None = None,
|
|
491
|
+
ExtraFormats: list[object] = [],
|
|
891
492
|
**kwargs,
|
|
892
493
|
):
|
|
893
494
|
|
|
894
|
-
async with create_async_session(self.engine) as session:
|
|
895
|
-
prompt_result = await self.get_prompt_safe(prompt_id=prompt_id,
|
|
896
|
-
session=session)
|
|
897
|
-
if prompt_result is None:
|
|
898
|
-
raise IntellectRemoveError("不存在的prompt_id")
|
|
899
|
-
if prompt_result.action_type != "inference":
|
|
900
|
-
input_datas = input_datas[:1]
|
|
901
495
|
tasks = []
|
|
902
496
|
for input_data in input_datas:
|
|
903
497
|
tasks.append(
|
|
904
|
-
self.
|
|
498
|
+
self.inference_format(
|
|
905
499
|
input_data = input_data,
|
|
906
500
|
prompt_id = prompt_id,
|
|
501
|
+
version = version,
|
|
907
502
|
OutputFormat = OutputFormat,
|
|
908
503
|
ExtraFormats = ExtraFormats,
|
|
909
|
-
version = version,
|
|
910
504
|
**kwargs,
|
|
911
505
|
)
|
|
912
506
|
)
|
|
913
|
-
results = await
|
|
507
|
+
results = await tqdm.gather(*tasks,total=len(tasks))
|
|
508
|
+
# results = await asyncio.gather(*tasks, return_exceptions=False)
|
|
914
509
|
return results
|
|
915
510
|
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
async def get_use_case(self,
|
|
515
|
+
target_prompt_id: str,
|
|
516
|
+
start_time: datetime = None, # 新增:开始时间
|
|
517
|
+
end_time: datetime = None, # 新增:结束时间
|
|
518
|
+
session = None
|
|
519
|
+
):
|
|
520
|
+
"""
|
|
521
|
+
从sql保存提示词
|
|
522
|
+
"""
|
|
523
|
+
stmt = select(UseCase).filter(UseCase.is_deleted == 0,
|
|
524
|
+
UseCase.prompt_id == target_prompt_id)
|
|
525
|
+
|
|
526
|
+
if start_time:
|
|
527
|
+
stmt = stmt.filter(UseCase.timestamp >= start_time) # 假设你的UseCase模型有一个created_at字段
|
|
528
|
+
|
|
529
|
+
if end_time:
|
|
530
|
+
stmt = stmt.filter(UseCase.timestamp <= end_time)
|
|
531
|
+
result = await session.execute(stmt)
|
|
532
|
+
# use_case = result.scalars().one_or_none()
|
|
533
|
+
use_case = result.scalars().all()
|
|
534
|
+
return use_case
|
|
535
|
+
|
|
536
|
+
async def save_use_case(self, log_file, session = None):
    """Ingest a delimited log file and persist each record as a ``UseCase``.

    The file is split into records on ``"||"``; each record is split into
    fields on ``"$"`` with layout
    ``[_, create_time, level, funcname, line, pathname, message]``.
    ``message`` is further split on ``"&"`` into target / type_ / content.
    All rows are added to *session* and committed once at the end.

    :param log_file: path of the log file to parse.
    :param session: async SQLAlchemy session receiving the new rows.
    :raises Exception: re-raises any parse/DB error after printing the
        offending record (records with 3-6 fields hit an IndexError here,
        preserved from the original behavior).
    """
    with open(log_file, 'r') as f:
        raw = f.read()

    def deal_log(resu):
        # Fewer than 3 fields: malformed fragment / trailing separator noise.
        if len(resu) < 3:
            return
        try:
            create_time = resu[1]
            level = resu[2]
            funcname = resu[3]
            line = resu[4]
            pathname = resu[5]
            message = resu[6]

            message_list = message.split("&")
            if len(message_list) == 3:
                target, type_, content = message_list
            elif len(message_list) == 2:
                target, type_ = message_list
                content = "只有两个"
            elif len(message_list) == 1:
                target = message_list[0]
                type_ = " "
                content = "只有一个"
            else:
                # BUGFIX: messages with more than 3 '&'-separated parts
                # previously left target/type_/content unbound (NameError).
                # Keep the first two fields and fold the rest back into
                # content so the record is still persisted.
                target, type_ = message_list[0], message_list[1]
                content = "&".join(message_list[2:])

            # BUGFIX: this module imports ``datetime`` as the class
            # (``from datetime import datetime``), so the original
            # ``datetime.datetime.fromtimestamp`` raised AttributeError.
            dt_object = datetime.fromtimestamp(float(create_time.strip()))
            use_case = UseCase(
                time = create_time,
                level = level,
                timestamp = dt_object.strftime('%Y-%m-%d %H:%M:%S.%f'),
                filepath = pathname,
                function = funcname,
                lines = line,
                type_ = type_,
                target = target,
                content = content,
            )
            session.add(use_case)
        except Exception:
            # Surface the offending record for diagnosis, then re-raise.
            print(resu, 'resu')
            raise

    for res in raw.split("||"):
        deal_log(res.split("$"))

    await session.commit()  # single commit persists every parsed row
|
|
935
585
|
|
|
936
|
-
#######
|
|
937
|
-
kwargs.update({"input_data":output_})
|
|
938
|
-
result = await func(*args, **kwargs)
|
|
939
|
-
return result
|
|
940
|
-
return wrapper
|
|
941
|
-
return outer_packing
|
|
942
586
|
|
|
943
587
|
async def intellect_format_eval(self,
|
|
944
|
-
OutputFormat: object,
|
|
945
588
|
prompt_id: str,
|
|
589
|
+
version: str = None,
|
|
946
590
|
database_url = None,
|
|
591
|
+
OutputFormat: object = None,
|
|
947
592
|
ExtraFormats: list[object] = [],
|
|
948
|
-
version: str = None,
|
|
949
593
|
MIN_SUCCESS_RATE = 80.0,
|
|
950
594
|
ConTent_Function = None,
|
|
951
595
|
AConTent_Function = None,
|
|
596
|
+
start = None,
|
|
597
|
+
end = None,
|
|
952
598
|
):
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
# TODO 人类评价 eval
|
|
956
|
-
# TODO llm 评价 eval
|
|
957
|
-
"""
|
|
599
|
+
# start = datetime(2023, 1, 1, 10, 0, 0)
|
|
600
|
+
# end = datetime(2023, 1, 15, 12, 30, 0)
|
|
958
601
|
async with create_async_session(self.engine) as session:
|
|
959
602
|
prompt_result = await self.get_prompt_safe(prompt_id=prompt_id,
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
raise IntellectRemoveError("不存在的prompt_id")
|
|
963
|
-
if prompt_result.action_type != "inference":
|
|
964
|
-
raise IntellectRemoveError("请在inference模式下使用次类")
|
|
603
|
+
version = version,
|
|
604
|
+
session=session)
|
|
965
605
|
|
|
966
606
|
if database_url:
|
|
967
607
|
eval_engine = create_async_engine(database_url, echo=False,
|
|
@@ -973,12 +613,11 @@ class AsyncIntel():
|
|
|
973
613
|
)
|
|
974
614
|
else:
|
|
975
615
|
eval_engine = self.engine
|
|
616
|
+
|
|
976
617
|
async with create_async_session(eval_engine) as eval_session:
|
|
977
|
-
# start = datetime(2023, 1, 1, 10, 0, 0)
|
|
978
|
-
# end = datetime(2023, 1, 15, 12, 30, 0)
|
|
979
618
|
use_cases = await self.get_use_case(target_prompt_id=prompt_id,session=eval_session,
|
|
980
|
-
start_time=
|
|
981
|
-
end_time=
|
|
619
|
+
start_time=start,
|
|
620
|
+
end_time=end,)
|
|
982
621
|
|
|
983
622
|
total_assertions = len(use_cases)
|
|
984
623
|
result_cases = []
|
|
@@ -1038,21 +677,13 @@ class AsyncIntel():
|
|
|
1038
677
|
|
|
1039
678
|
success_rate = (successful_assertions / total_assertions) * 100
|
|
1040
679
|
|
|
680
|
+
status = "通过" if success_rate >= MIN_SUCCESS_RATE else "未通过"
|
|
1041
681
|
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
'status':"通过",
|
|
682
|
+
self.eval_df.loc[len(self.eval_df)] = {"name":prompt_id,
|
|
683
|
+
'status':status,
|
|
1045
684
|
"score":success_rate,
|
|
1046
685
|
"total":str(total_assertions),
|
|
1047
686
|
"bad_case":json.dumps(bad_case,ensure_ascii=False)}
|
|
1048
|
-
return "通过", success_rate, str(total_assertions), json.dumps(bad_case,ensure_ascii=False),
|
|
1049
|
-
else:
|
|
1050
|
-
self.eval_df.loc[len(self.eval_df)] = {"name":prompt_id,
|
|
1051
|
-
'status':"未通过",
|
|
1052
|
-
"score":success_rate,
|
|
1053
|
-
"total":str(total_assertions),
|
|
1054
|
-
"bad_case":json.dumps(bad_case,ensure_ascii=False)}
|
|
1055
|
-
return "未通过",success_rate, str(total_assertions), json.dumps(bad_case,ensure_ascii=False),
|
|
1056
687
|
|
|
1057
688
|
|
|
1058
689
|
async def function_eval(self,
|
|
@@ -1071,12 +702,8 @@ class AsyncIntel():
|
|
|
1071
702
|
# TODO llm 评价 eval
|
|
1072
703
|
"""
|
|
1073
704
|
async with create_async_session(self.engine) as session:
|
|
1074
|
-
|
|
705
|
+
await self.get_prompt_safe(prompt_id=prompt_id,
|
|
1075
706
|
session=session)
|
|
1076
|
-
if prompt_result is None:
|
|
1077
|
-
raise IntellectRemoveError("不存在的prompt_id")
|
|
1078
|
-
if prompt_result.action_type != "inference":
|
|
1079
|
-
raise IntellectRemoveError("请在inference模式下使用次类")
|
|
1080
707
|
|
|
1081
708
|
if database_url:
|
|
1082
709
|
eval_engine = create_async_engine(database_url, echo=False,
|