flowllm-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.
- flowllm-0.1.0.dist-info/METADATA +597 -0
- flowllm-0.1.0.dist-info/RECORD +66 -0
- flowllm-0.1.0.dist-info/WHEEL +5 -0
- flowllm-0.1.0.dist-info/entry_points.txt +3 -0
- flowllm-0.1.0.dist-info/licenses/LICENSE +201 -0
- flowllm-0.1.0.dist-info/top_level.txt +1 -0
- llmflow/__init__.py +0 -0
- llmflow/app.py +53 -0
- llmflow/config/__init__.py +0 -0
- llmflow/config/config_parser.py +80 -0
- llmflow/config/mock_config.yaml +58 -0
- llmflow/embedding_model/__init__.py +5 -0
- llmflow/embedding_model/base_embedding_model.py +104 -0
- llmflow/embedding_model/openai_compatible_embedding_model.py +95 -0
- llmflow/enumeration/__init__.py +0 -0
- llmflow/enumeration/agent_state.py +8 -0
- llmflow/enumeration/chunk_enum.py +9 -0
- llmflow/enumeration/http_enum.py +9 -0
- llmflow/enumeration/role.py +8 -0
- llmflow/llm/__init__.py +5 -0
- llmflow/llm/base_llm.py +138 -0
- llmflow/llm/openai_compatible_llm.py +283 -0
- llmflow/mcp_server.py +110 -0
- llmflow/op/__init__.py +10 -0
- llmflow/op/base_op.py +125 -0
- llmflow/op/mock_op.py +40 -0
- llmflow/op/prompt_mixin.py +74 -0
- llmflow/op/react/__init__.py +0 -0
- llmflow/op/react/react_v1_op.py +88 -0
- llmflow/op/react/react_v1_prompt.yaml +28 -0
- llmflow/op/vector_store/__init__.py +13 -0
- llmflow/op/vector_store/recall_vector_store_op.py +48 -0
- llmflow/op/vector_store/update_vector_store_op.py +28 -0
- llmflow/op/vector_store/vector_store_action_op.py +46 -0
- llmflow/pipeline/__init__.py +0 -0
- llmflow/pipeline/pipeline.py +94 -0
- llmflow/pipeline/pipeline_context.py +37 -0
- llmflow/schema/__init__.py +0 -0
- llmflow/schema/app_config.py +69 -0
- llmflow/schema/experience.py +144 -0
- llmflow/schema/message.py +68 -0
- llmflow/schema/request.py +32 -0
- llmflow/schema/response.py +29 -0
- llmflow/schema/vector_node.py +11 -0
- llmflow/service/__init__.py +0 -0
- llmflow/service/llmflow_service.py +96 -0
- llmflow/tool/__init__.py +9 -0
- llmflow/tool/base_tool.py +80 -0
- llmflow/tool/code_tool.py +43 -0
- llmflow/tool/dashscope_search_tool.py +162 -0
- llmflow/tool/mcp_tool.py +77 -0
- llmflow/tool/tavily_search_tool.py +109 -0
- llmflow/tool/terminate_tool.py +23 -0
- llmflow/utils/__init__.py +0 -0
- llmflow/utils/common_utils.py +17 -0
- llmflow/utils/file_handler.py +25 -0
- llmflow/utils/http_client.py +156 -0
- llmflow/utils/op_utils.py +102 -0
- llmflow/utils/registry.py +33 -0
- llmflow/utils/singleton.py +9 -0
- llmflow/utils/timer.py +53 -0
- llmflow/vector_store/__init__.py +7 -0
- llmflow/vector_store/base_vector_store.py +136 -0
- llmflow/vector_store/chroma_vector_store.py +188 -0
- llmflow/vector_store/es_vector_store.py +227 -0
- llmflow/vector_store/file_vector_store.py +163 -0
llmflow/tool/mcp_tool.py
ADDED
@@ -0,0 +1,77 @@
import asyncio
from typing import List

from mcp import ClientSession
from mcp.client.sse import sse_client
from pydantic import Field, model_validator

from llmflow.tool import TOOL_REGISTRY
from llmflow.tool.base_tool import BaseTool


@TOOL_REGISTRY.register()
class MCPTool(BaseTool):
    server_url: str = Field(..., description="MCP server URL")
    tool_name_list: List[str] = Field(default_factory=list)
    cache_tools: dict = Field(default_factory=dict, alias="cache_tools")

    @model_validator(mode="after")
    def refresh_tools(self):
        self.refresh()
        return self

    async def _get_tools(self):
        async with sse_client(url=self.server_url) as streams:
            async with ClientSession(streams[0], streams[1]) as session:
                await session.initialize()
                tools = await session.list_tools()
                return tools

    def refresh(self):
        self.tool_name_list.clear()
        self.cache_tools.clear()

        if "sse" in self.server_url:
            original_tool_list = asyncio.run(self._get_tools())
            for tool in original_tool_list.tools:
                self.cache_tools[tool.name] = tool
                self.tool_name_list.append(tool.name)
        else:
            raise NotImplementedError("Non-SSE refresh not implemented yet")

    @property
    def input_schema(self) -> dict:
        return {x: self.cache_tools[x].inputSchema for x in self.cache_tools}

    @property
    def output_schema(self) -> dict:
        raise NotImplementedError("Output schema not implemented yet")

    def get_tool_description(self, tool_name: str, schema: bool = False) -> str:
        if tool_name not in self.cache_tools:
            raise RuntimeError(f"Tool {tool_name} not found")

        tool = self.cache_tools.get(tool_name)
        description = f"tool={tool_name} description={tool.description}\n"
        if schema:
            # note: output_schema raises NotImplementedError, so schema=True
            # currently fails for every tool
            description += f"input_schema={self.input_schema[tool_name]}\n" \
                           f"output_schema={self.output_schema[tool_name]}\n"
        return description.strip()

    async def async_execute(self, tool_name: str, **kwargs):
        if "sse" in self.server_url:
            async with sse_client(url=self.server_url) as streams:
                async with ClientSession(streams[0], streams[1]) as session:
                    await session.initialize()
                    results = await session.call_tool(tool_name, kwargs)
                    return results.content[0].text, results.isError
        else:
            raise NotImplementedError("Non-SSE execute not implemented yet")

    def _execute(self, **kwargs):
        return asyncio.run(self.async_execute(**kwargs))

    def get_cache_id(self, **kwargs) -> str:
        # Build a unique cache ID from the tool name and a hash of the arguments.
        return f"{kwargs.get('tool_name')}_{hash(frozenset(kwargs.get('args', {}).items()))}"
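For illustration (not part of the package): a minimal usage sketch. It assumes an MCP server exposes an SSE endpoint at a hypothetical URL and that BaseTool's remaining fields have defaults; constructing the model triggers refresh(), which lists the server's tools.

from llmflow.tool.mcp_tool import MCPTool

# Hypothetical endpoint; any URL containing "sse" takes the implemented branch.
mcp_tool = MCPTool(server_url="http://localhost:8001/sse")
print(mcp_tool.tool_name_list)                                   # names discovered from the server
print(mcp_tool.get_tool_description(mcp_tool.tool_name_list[0]))

# _execute wraps async_execute with asyncio.run; tool_name selects the remote
# tool and the remaining kwargs become its arguments.
text, is_error = mcp_tool._execute(tool_name=mcp_tool.tool_name_list[0])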
llmflow/tool/tavily_search_tool.py
ADDED
@@ -0,0 +1,109 @@
import json
import os
import re
import time
from typing import Literal

from loguru import logger
from pydantic import Field, model_validator, PrivateAttr
from tavily import TavilyClient

from llmflow.tool import TOOL_REGISTRY
from llmflow.tool.base_tool import BaseTool


@TOOL_REGISTRY.register()
class TavilySearchTool(BaseTool):
    name: str = "web_search"
    description: str = "Use query to retrieve relevant information from the internet."
    parameters: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "search query",
            }
        },
        "required": ["query"]
    }
    enable_print: bool = Field(default=True)
    enable_cache: bool = Field(default=False)
    cache_path: str = Field(default="./web_search_cache")
    topic: Literal["general", "news", "finance"] = Field(default="general", description="general, news, finance")

    _client: TavilyClient | None = PrivateAttr()

    @model_validator(mode="after")
    def init(self):
        if not os.path.exists(self.cache_path):
            os.makedirs(self.cache_path)

        self._client = TavilyClient()
        return self

    def load_cache(self, cache_name: str = "default") -> dict:
        cache_file = os.path.join(self.cache_path, cache_name + ".jsonl")
        if not os.path.exists(cache_file):
            return {}

        with open(cache_file) as f:
            return json.load(f)

    def dump_cache(self, cache_dict: dict, cache_name: str = "default"):
        cache_file = os.path.join(self.cache_path, cache_name + ".jsonl")
        with open(cache_file, "w") as f:
            return json.dump(cache_dict, f, indent=2, ensure_ascii=False)

    @staticmethod
    def remove_urls_and_images(text):
        pattern = re.compile(r'https?://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]')
        result = pattern.sub("", text)
        return result

    def post_process(self, response):
        if self.enable_print:
            logger.info("response=\n" + json.dumps(response, indent=2, ensure_ascii=False))

        return response

    def execute(self, query: str = "", **kwargs):
        assert query, "Query cannot be empty"

        cache_dict = {}
        if self.enable_cache:
            cache_dict = self.load_cache()
            if query in cache_dict:
                return self.post_process(cache_dict[query])

        for i in range(self.max_retries):
            try:
                response = self._client.search(query=query, topic=self.topic)
                url_info_dict = {item["url"]: item for item in response["results"]}
                response_extract = self._client.extract(urls=[item["url"] for item in response["results"]],
                                                        format="text")

                final_result = {}
                for item in response_extract["results"]:
                    url = item["url"]
                    final_result[url] = url_info_dict[url]
                    final_result[url]["raw_content"] = item["raw_content"]

                if self.enable_cache:
                    cache_dict[query] = final_result
                    self.dump_cache(cache_dict)

                return self.post_process(final_result)

            except Exception as e:
                logger.exception(f"tavily search with query={query} encounter error with e={e.args}")
                time.sleep(i + 1)

        return None


if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv()
    tool = TavilySearchTool()
    tool.execute(query="Why do A-share pharmaceutical stocks keep rising?")
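A note for readers (not part of the package): TavilyClient() is constructed without an explicit key, so the demo presumably relies on load_dotenv() putting TAVILY_API_KEY into the environment; likewise, execute() loops over self.max_retries, which is not declared in this file and presumably comes from BaseTool. A hedged sketch of the caching path:

from llmflow.tool.tavily_search_tool import TavilySearchTool

# Cache entries are keyed by query string in ./web_search_cache/default.jsonl
# (written as a single JSON document despite the .jsonl suffix).
tool = TavilySearchTool(enable_cache=True, topic="news")
result = tool.execute(query="latest LLM releases")  # hits Tavily, then caches
result = tool.execute(query="latest LLM releases")  # served from the cache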
llmflow/tool/terminate_tool.py
ADDED
@@ -0,0 +1,23 @@
from llmflow.tool import TOOL_REGISTRY
from llmflow.tool.base_tool import BaseTool


@TOOL_REGISTRY.register()
class TerminateTool(BaseTool):
    name: str = "terminate"
    description: str = "If you can answer the user's question based on the context, be sure to use the **terminate** tool."
    parameters: dict = {
        "type": "object",
        "properties": {
            "status": {
                "type": "string",
                "description": "Please determine whether the user's question has been completed. (success / failure)",
                "enum": ["success", "failure"],
            }
        },
        "required": ["status"],
    }

    def execute(self, status: str):
        self.success = status in ["success", "failure"]
        return f"The interaction has been completed with status: {status}"
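For illustration (not part of the package): the parameters dict is the JSON schema an agent loop would advertise to the model, and execute() records whether a recognized status came back. A minimal sketch, assuming BaseTool lets the class be constructed without extra arguments:

from llmflow.tool.terminate_tool import TerminateTool

terminate = TerminateTool()
print(terminate.parameters["properties"]["status"]["enum"])  # ['success', 'failure']
print(terminate.execute(status="success"))
# -> The interaction has been completed with status: success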
llmflow/utils/__init__.py
File without changes
llmflow/utils/common_utils.py
ADDED
@@ -0,0 +1,17 @@
import re


def camel_to_snake(content: str) -> str:
    """
    BaseWorker -> base_worker
    """
    snake_str = re.sub(r'(?<!^)(?=[A-Z])', '_', content).lower()
    return snake_str


def snake_to_camel(content: str) -> str:
    """
    base_worker -> BaseWorker
    """
    camel_str = "".join(x.capitalize() for x in content.split("_"))
    return camel_str
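A quick illustration (not part of the package) of the two helpers. The round trip is only lossless for simple CamelCase names; runs of capitals split on every letter:

from llmflow.utils.common_utils import camel_to_snake, snake_to_camel

assert camel_to_snake("BaseWorker") == "base_worker"
assert snake_to_camel("base_worker") == "BaseWorker"
print(camel_to_snake("HTTPClient"))  # h_t_t_p_client, not http_client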
llmflow/utils/file_handler.py
ADDED
@@ -0,0 +1,25 @@
import json
from pathlib import Path

import yaml


class FileHandler:

    def __init__(self, file_path: str | Path):
        self.file_path: Path = Path(file_path)
        suffix = self.file_path.suffix
        if suffix == ".json":
            self._obj = json
        elif suffix == ".yaml":
            self._obj = yaml
        else:
            raise ValueError(f"unsupported file type={suffix}")

    def dump(self, config, **kwargs):
        with open(self.file_path, "w") as f:
            self._obj.dump(config, f, **kwargs)

    def load(self, **kwargs):
        with open(self.file_path, "r") as f:
            return self._obj.load(f, **kwargs)
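For illustration (not part of the package): FileHandler dispatches on the file suffix, and extra keyword arguments pass straight through to the underlying json/yaml calls; recent PyYAML versions require an explicit Loader, which fits that pattern. The path and config keys below are hypothetical:

import yaml
from llmflow.utils.file_handler import FileHandler

handler = FileHandler("config.yaml")
handler.dump({"llm": {"model": "demo-model"}})   # -> yaml.dump(config, f)
config = handler.load(Loader=yaml.SafeLoader)    # -> yaml.load(f, Loader=...)
print(config["llm"]["model"])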
llmflow/utils/http_client.py
ADDED
@@ -0,0 +1,156 @@
import http
import time
from typing import Any

import requests
from loguru import logger
from pydantic import BaseModel, Field, PrivateAttr, model_validator

from llmflow.enumeration.http_enum import HttpEnum


class HttpClient(BaseModel):
    url: str = Field(default="")
    keep_alive: bool = Field(default=False, description="if true, use session to keep long connection")
    timeout: int = Field(default=300, description="request timeout, seconds")

    return_default_if_error: bool = Field(default=True)
    request_start_time: float = Field(default_factory=time.time)
    request_time_cost: float = Field(default=0.0, description="request time cost")

    retry_sleep_time: float = Field(default=0.5, description="interval time for retry")
    retry_time_multiplier: float = Field(default=2.0, description="retry time multiplier")
    retry_max_count: int = Field(default=1, description="maximum number of retries")

    _client: Any = PrivateAttr()

    @model_validator(mode="after")
    def init_client(self):
        # use a Session for keep-alive connections, otherwise the module-level API
        self._client = requests.Session() if self.keep_alive else requests
        return self

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
        self.request_time_cost = time.time() - self.request_start_time

    def close(self):
        if isinstance(self._client, requests.Session):
            self._client.close()

    def _request(self,
                 data: str = None,
                 json_data: dict = None,
                 headers: dict = None,
                 stream: bool = False,
                 http_enum: HttpEnum | str = HttpEnum.POST):

        if isinstance(http_enum, str):
            http_enum = HttpEnum(http_enum)

        if http_enum is HttpEnum.POST:
            response: requests.Response = self._client.post(url=self.url,
                                                            data=data,
                                                            json=json_data,
                                                            headers=headers,
                                                            stream=stream,
                                                            timeout=self.timeout)

        elif http_enum is HttpEnum.GET:
            response: requests.Response = self._client.get(url=self.url,
                                                           data=data,
                                                           json=json_data,
                                                           headers=headers,
                                                           stream=stream,
                                                           timeout=self.timeout)

        else:
            raise NotImplementedError

        if response.status_code != http.HTTPStatus.OK:
            raise RuntimeError(f"request failed! content={response.json()}")

        return response

    def parse_result(self, response: requests.Response | Any = None, **kwargs):
        return response.json()

    def return_default(self, **kwargs):
        return None

    def request(self,
                data: str | Any = None,
                json_data: dict = None,
                headers: dict = None,
                http_enum: HttpEnum | str = HttpEnum.POST,
                **kwargs):

        retry_sleep_time = self.retry_sleep_time
        for i in range(self.retry_max_count):
            try:
                response = self._request(data=data, json_data=json_data, headers=headers, http_enum=http_enum)
                result = self.parse_result(response=response,
                                           data=data,
                                           json_data=json_data,
                                           headers=headers,
                                           http_enum=http_enum,
                                           **kwargs)
                return result

            except Exception as e:
                logger.exception(f"{self.__class__.__name__} {i}th request failed with args={e.args}")

                if i == self.retry_max_count - 1:
                    if self.return_default_if_error:
                        return self.return_default()
                    else:
                        raise e

                # exponential backoff between retries
                retry_sleep_time *= self.retry_time_multiplier
                time.sleep(retry_sleep_time)

        return None

    def request_stream(self,
                       data: str = None,
                       json_data: dict = None,
                       headers: dict = None,
                       http_enum: HttpEnum | str = HttpEnum.POST,
                       **kwargs):

        retry_sleep_time = self.retry_sleep_time
        for i in range(self.retry_max_count):
            try:
                response = self._request(data=data,
                                         json_data=json_data,
                                         headers=headers,
                                         stream=True,
                                         http_enum=http_enum)
                request_context = {}
                for iter_idx, line in enumerate(response.iter_lines()):
                    yield self.parse_result(line=line,
                                            request_context=request_context,
                                            index=iter_idx,
                                            data=data,
                                            json_data=json_data,
                                            headers=headers,
                                            http_enum=http_enum,
                                            **kwargs)

                return None

            except Exception as e:
                logger.exception(f"{self.__class__.__name__} {i}th request failed with args={e.args}")

                if i == self.retry_max_count - 1:
                    if self.return_default_if_error:
                        return self.return_default()
                    else:
                        raise e

                retry_sleep_time *= self.retry_time_multiplier
                time.sleep(retry_sleep_time)

        return None
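For illustration (not part of the package): a hedged sketch of the non-streaming path, using a keep-alive session with retries. The endpoint URL is hypothetical; when return_default_if_error is set, request() falls back to return_default() (None) after the retries are exhausted:

from llmflow.utils.http_client import HttpClient

with HttpClient(url="http://localhost:8000/chat", keep_alive=True, retry_max_count=3) as client:
    result = client.request(json_data={"query": "hello"})

print(result)  # parsed JSON on success, None after exhausted retries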
llmflow/utils/op_utils.py
ADDED
@@ -0,0 +1,102 @@
import json
import re
from typing import List

from loguru import logger

from llmflow.enumeration.role import Role
from llmflow.schema.message import Message, Trajectory


def merge_messages_content(messages: List[Message | dict]) -> str:
    content_collector = []
    for i, message in enumerate(messages):
        if isinstance(message, dict):
            message = Message(**message)

        if message.role is Role.ASSISTANT:
            line = f"### step.{i} role={message.role.value} content=\n{message.reasoning_content}\n\n{message.content}\n"
            if message.tool_calls:
                for tool_call in message.tool_calls:
                    line += f" - tool call={tool_call.name}\n params={tool_call.arguments}\n"
            content_collector.append(line)

        elif message.role is Role.USER:
            line = f"### step.{i} role={message.role.value} content=\n{message.content}\n"
            content_collector.append(line)

        elif message.role is Role.TOOL:
            line = f"### step.{i} role={message.role.value} tool call result=\n{message.content}\n"
            content_collector.append(line)

    return "\n".join(content_collector)


def parse_json_experience_response(response: str) -> List[dict]:
    """Parse a JSON-formatted experience response."""
    try:
        # Extract fenced JSON blocks first
        json_pattern = r'```json\s*([\s\S]*?)\s*```'
        json_blocks = re.findall(json_pattern, response)

        if json_blocks:
            parsed = json.loads(json_blocks[0])

            # Handle array format
            if isinstance(parsed, list):
                experiences = []
                for exp_data in parsed:
                    if isinstance(exp_data, dict) and (
                            ("when_to_use" in exp_data and "experience" in exp_data) or
                            ("condition" in exp_data and "experience" in exp_data)
                    ):
                        experiences.append(exp_data)

                return experiences

            # Handle single object
            elif isinstance(parsed, dict) and (
                    ("when_to_use" in parsed and "experience" in parsed) or
                    ("condition" in parsed and "experience" in parsed)
            ):
                return [parsed]

        # Fallback: try to parse the entire response
        parsed = json.loads(response)
        if isinstance(parsed, list):
            return parsed
        elif isinstance(parsed, dict):
            return [parsed]

    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse JSON experience response: {e}")

    return []


def get_trajectory_context(trajectory: Trajectory, step_sequence: List[Message]) -> str:
    """Get the context of a step sequence within a trajectory."""
    try:
        # Find the position of the step sequence in the trajectory
        start_idx = 0
        for i, step in enumerate(trajectory.messages):
            if step == step_sequence[0]:
                start_idx = i
                break

        # Extract context before and after the sequence
        context_before = trajectory.messages[max(0, start_idx - 2):start_idx]
        context_after = trajectory.messages[start_idx + len(step_sequence):start_idx + len(step_sequence) + 2]

        context = f"Query: {trajectory.metadata.get('query', 'N/A')}\n"

        if context_before:
            context += "Previous steps:\n" + "\n".join(
                [f"- {step.content[:100]}..." for step in context_before]) + "\n"

        if context_after:
            context += "Following steps:\n" + "\n".join([f"- {step.content[:100]}..." for step in context_after])

        return context

    except Exception as e:
        logger.error(f"Error getting trajectory context: {e}")
        return f"Query: {trajectory.metadata.get('query', 'N/A')}"
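For illustration (not part of the package): a sketch of the parser's contract. It prefers fenced ```json blocks and keeps only dicts carrying an "experience" key alongside "when_to_use" or "condition":

from llmflow.utils.op_utils import parse_json_experience_response

raw = '''```json
[{"when_to_use": "rate-limited API", "experience": "back off exponentially"},
 {"note": "missing required keys, so this entry is dropped"}]
```'''
print(parse_json_experience_response(raw))
# -> [{'when_to_use': 'rate-limited API', 'experience': 'back off exponentially'}]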
llmflow/utils/registry.py
ADDED
@@ -0,0 +1,33 @@
from typing import List

from loguru import logger

from llmflow.utils.common_utils import camel_to_snake


class Registry(object):
    def __init__(self):
        self._registry = {}

    def register(self, name: str = ""):

        def decorator(cls):
            class_name = name if name else camel_to_snake(cls.__name__)
            if class_name in self._registry:
                logger.warning(f"name={class_name} is already registered, will be overwritten.")
            self._registry[class_name] = cls
            return cls

        return decorator

    def __getitem__(self, name: str):
        if name not in self._registry:
            raise KeyError(f"name={name} is not registered!")
        return self._registry[name]

    def __contains__(self, name: str):
        return name in self._registry

    @property
    def registered_names(self) -> List[str]:
        return sorted(self._registry.keys())
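For illustration (not part of the package): register() keys classes by their snake_case name unless one is given, so registered classes can be looked up from config strings. The demo classes are hypothetical:

from llmflow.utils.registry import Registry

DEMO_REGISTRY = Registry()

@DEMO_REGISTRY.register()
class MockSearchOp:
    pass

@DEMO_REGISTRY.register(name="custom_name")
class OtherOp:
    pass

print("mock_search_op" in DEMO_REGISTRY)  # True
print(DEMO_REGISTRY.registered_names)     # ['custom_name', 'mock_search_op']
op_cls = DEMO_REGISTRY["mock_search_op"]  # class lookup by name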
llmflow/utils/timer.py
ADDED
@@ -0,0 +1,53 @@
import time

from loguru import logger


class Timer(object):
    def __init__(self, name: str, use_ms: bool = False, stack_level: int = 2):
        self.name: str = name
        self.use_ms: bool = use_ms
        self.stack_level: int = stack_level

        self.time_start: float = 0
        self.time_end: float = 0
        self.time_cost: float = 0

    def __enter__(self, *args, **kwargs):
        self.time_start = time.time()
        logger.info(f"---------- enter {self.name} ----------", stacklevel=self.stack_level)
        return self

    def __exit__(self, *args):
        self.time_end = time.time()
        self.time_cost = self.time_end - self.time_start
        if self.use_ms:
            time_str = f"{self.time_cost * 1000:.2f}ms"
        else:
            time_str = f"{self.time_cost:.3f}s"

        logger.info(f"---------- leave {self.name} [{time_str}] ----------", stacklevel=self.stack_level)


def timer(name: str = None, use_ms: bool = False, stack_level: int = 2):
    def decorator(func):
        def wrapper(*args, **kwargs):
            with Timer(name=name or func.__name__, use_ms=use_ms, stack_level=stack_level + 1):
                return func(*args, **kwargs)

        return wrapper

    return decorator


if __name__ == "__main__":
    import random


    @timer("run_func_final", use_ms=True)
    def run_func():
        time.sleep(random.uniform(0.05, 0.15))
        print("done")


    run_func()
llmflow/vector_store/__init__.py
ADDED
@@ -0,0 +1,7 @@
from llmflow.utils.registry import Registry

# the store modules imported below register themselves against this registry,
# so it must be defined before they are imported
VECTOR_STORE_REGISTRY = Registry()

from llmflow.vector_store.es_vector_store import EsVectorStore
from llmflow.vector_store.chroma_vector_store import ChromaVectorStore
from llmflow.vector_store.file_vector_store import FileVectorStore