flowllm-0.1.0-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- flowllm-0.1.0.dist-info/METADATA +597 -0
- flowllm-0.1.0.dist-info/RECORD +66 -0
- flowllm-0.1.0.dist-info/WHEEL +5 -0
- flowllm-0.1.0.dist-info/entry_points.txt +3 -0
- flowllm-0.1.0.dist-info/licenses/LICENSE +201 -0
- flowllm-0.1.0.dist-info/top_level.txt +1 -0
- llmflow/__init__.py +0 -0
- llmflow/app.py +53 -0
- llmflow/config/__init__.py +0 -0
- llmflow/config/config_parser.py +80 -0
- llmflow/config/mock_config.yaml +58 -0
- llmflow/embedding_model/__init__.py +5 -0
- llmflow/embedding_model/base_embedding_model.py +104 -0
- llmflow/embedding_model/openai_compatible_embedding_model.py +95 -0
- llmflow/enumeration/__init__.py +0 -0
- llmflow/enumeration/agent_state.py +8 -0
- llmflow/enumeration/chunk_enum.py +9 -0
- llmflow/enumeration/http_enum.py +9 -0
- llmflow/enumeration/role.py +8 -0
- llmflow/llm/__init__.py +5 -0
- llmflow/llm/base_llm.py +138 -0
- llmflow/llm/openai_compatible_llm.py +283 -0
- llmflow/mcp_server.py +110 -0
- llmflow/op/__init__.py +10 -0
- llmflow/op/base_op.py +125 -0
- llmflow/op/mock_op.py +40 -0
- llmflow/op/prompt_mixin.py +74 -0
- llmflow/op/react/__init__.py +0 -0
- llmflow/op/react/react_v1_op.py +88 -0
- llmflow/op/react/react_v1_prompt.yaml +28 -0
- llmflow/op/vector_store/__init__.py +13 -0
- llmflow/op/vector_store/recall_vector_store_op.py +48 -0
- llmflow/op/vector_store/update_vector_store_op.py +28 -0
- llmflow/op/vector_store/vector_store_action_op.py +46 -0
- llmflow/pipeline/__init__.py +0 -0
- llmflow/pipeline/pipeline.py +94 -0
- llmflow/pipeline/pipeline_context.py +37 -0
- llmflow/schema/__init__.py +0 -0
- llmflow/schema/app_config.py +69 -0
- llmflow/schema/experience.py +144 -0
- llmflow/schema/message.py +68 -0
- llmflow/schema/request.py +32 -0
- llmflow/schema/response.py +29 -0
- llmflow/schema/vector_node.py +11 -0
- llmflow/service/__init__.py +0 -0
- llmflow/service/llmflow_service.py +96 -0
- llmflow/tool/__init__.py +9 -0
- llmflow/tool/base_tool.py +80 -0
- llmflow/tool/code_tool.py +43 -0
- llmflow/tool/dashscope_search_tool.py +162 -0
- llmflow/tool/mcp_tool.py +77 -0
- llmflow/tool/tavily_search_tool.py +109 -0
- llmflow/tool/terminate_tool.py +23 -0
- llmflow/utils/__init__.py +0 -0
- llmflow/utils/common_utils.py +17 -0
- llmflow/utils/file_handler.py +25 -0
- llmflow/utils/http_client.py +156 -0
- llmflow/utils/op_utils.py +102 -0
- llmflow/utils/registry.py +33 -0
- llmflow/utils/singleton.py +9 -0
- llmflow/utils/timer.py +53 -0
- llmflow/vector_store/__init__.py +7 -0
- llmflow/vector_store/base_vector_store.py +136 -0
- llmflow/vector_store/chroma_vector_store.py +188 -0
- llmflow/vector_store/es_vector_store.py +227 -0
- llmflow/vector_store/file_vector_store.py +163 -0
llmflow/op/mock_op.py
ADDED
@@ -0,0 +1,40 @@
import time

from loguru import logger

from llmflow.op import OP_REGISTRY
from llmflow.op.base_op import BaseOp


@OP_REGISTRY.register()
class Mock1Op(BaseOp):
    def execute(self):
        time.sleep(1)
        a: int = self.op_params["a"]
        b: str = self.op_params["b"]
        logger.info(f"enter class={self.simple_name}. a={a} b={b}")


@OP_REGISTRY.register()
class Mock2Op(Mock1Op):
    ...


@OP_REGISTRY.register()
class Mock3Op(Mock1Op):
    ...


@OP_REGISTRY.register()
class Mock4Op(Mock1Op):
    ...


@OP_REGISTRY.register()
class Mock5Op(Mock1Op):
    ...


@OP_REGISTRY.register()
class Mock6Op(Mock1Op):
    ...
llmflow/op/prompt_mixin.py
ADDED
@@ -0,0 +1,74 @@
from pathlib import Path

import yaml
from loguru import logger


class PromptMixin:

    def __init__(self):
        self._prompt_dict: dict = {}

    def load_prompt_by_file(self, prompt_file_path: Path | str = None):
        if prompt_file_path is None:
            return

        if isinstance(prompt_file_path, str):
            prompt_file_path = Path(prompt_file_path)

        if not prompt_file_path.exists():
            return

        with prompt_file_path.open() as f:
            prompt_dict = yaml.load(f, yaml.FullLoader)
            self.load_prompt_dict(prompt_dict)

    def load_prompt_dict(self, prompt_dict: dict = None):
        if not prompt_dict:
            return

        for key, value in prompt_dict.items():
            if isinstance(value, str):
                if key in self._prompt_dict:
                    self._prompt_dict[key] = value
                    logger.warning(f"prompt_dict key={key} overwrite!")
                else:
                    self._prompt_dict[key] = value
                    logger.info(f"add prompt_dict key={key}")

    def prompt_format(self, prompt_name: str, **kwargs):
        prompt = self._prompt_dict[prompt_name]

        flag_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, bool)}
        other_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, bool)}

        if flag_kwargs:
            split_prompt = []
            for line in prompt.strip().split("\n"):
                hit = False
                hit_flag = True
                for key, flag in kwargs.items():
                    if not line.startswith(f"[{key}]"):
                        continue
                    else:
                        hit = True
                        hit_flag = flag
                        line = line.strip(f"[{key}]")
                        break

                if not hit:
                    split_prompt.append(line)
                elif hit_flag:
                    split_prompt.append(line)

            prompt = "\n".join(split_prompt)

        if other_kwargs:
            prompt = prompt.format(**other_kwargs)

        return prompt

    def get_prompt(self, key: str):
        return self._prompt_dict[key]
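A quick usage sketch of the flag filtering in prompt_format above: lines prefixed with a bracketed flag name are kept or dropped according to boolean keyword arguments, and the remaining placeholders are filled with str.format afterwards. The prompt text and key here are made up for illustration:

```python
from llmflow.op.prompt_mixin import PromptMixin

mixin = PromptMixin()
mixin.load_prompt_dict({
    "demo_prompt": "Answer the question.\n"
                   "[verbose]Explain your reasoning step by step.\n"
                   "Question: {query}",
})

# verbose=False drops the "[verbose]" line; {query} is substituted afterwards.
print(mixin.prompt_format("demo_prompt", verbose=False, query="What is 2 + 2?"))
# Answer the question.
# Question: What is 2 + 2?
```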
llmflow/op/react/__init__.py
File without changes
llmflow/op/react/react_v1_op.py
ADDED
@@ -0,0 +1,88 @@
import datetime
import time
from typing import List, Dict

from loguru import logger

from llmflow.enumeration.role import Role
from llmflow.op import OP_REGISTRY
from llmflow.op.base_op import BaseOp
from llmflow.schema.message import Message
from llmflow.schema.request import AgentRequest
from llmflow.schema.response import AgentResponse
from llmflow.tool import TOOL_REGISTRY
from llmflow.tool.base_tool import BaseTool


@OP_REGISTRY.register()
class ReactV1Op(BaseOp):
    current_path: str = __file__

    def execute(self):
        request: AgentRequest = self.context.request
        response: AgentResponse = self.context.response

        max_steps: int = int(self.op_params.get("max_steps", 10))
        # dashscope_search_tool tavily_search_tool
        tool_names = self.op_params.get("tool_names", "code_tool,tavily_search_tool,terminate_tool")
        tools: List[BaseTool] = [TOOL_REGISTRY[x.strip()]() for x in tool_names.split(",") if x]
        tool_dict: Dict[str, BaseTool] = {x.name: x for x in tools}
        now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        has_terminate_tool = False

        user_prompt = self.prompt_format(prompt_name="role_prompt",
                                         time=now_time,
                                         tools=",".join([x.name for x in tools]),
                                         query=request.query)
        messages: List[Message] = [Message(role=Role.USER, content=user_prompt)]
        logger.info(f"step.0 user_prompt={user_prompt}")

        for i in range(max_steps):
            if has_terminate_tool:
                assistant_message: Message = self.llm.chat(messages)
            else:
                assistant_message: Message = self.llm.chat(messages, tools=tools)

            messages.append(assistant_message)
            logger.info(f"assistant.{i}.reasoning_content={assistant_message.reasoning_content}\n"
                        f"content={assistant_message.content}\n"
                        f"tool.size={len(assistant_message.tool_calls)}")

            if has_terminate_tool:
                break

            for tool in assistant_message.tool_calls:
                if tool.name == "terminate":
                    has_terminate_tool = True
                    logger.info(f"step={i} find terminate tool, break.")
                    break

            if not has_terminate_tool and not assistant_message.tool_calls:
                logger.warning(f"【bugfix】step={i} no tools, break.")
                has_terminate_tool = True

            for j, tool_call in enumerate(assistant_message.tool_calls):
                logger.info(f"submit step={i} tool_calls.name={tool_call.name} argument_dict={tool_call.argument_dict}")

                if tool_call.name not in tool_dict:
                    continue

                self.submit_task(tool_dict[tool_call.name].execute, **tool_call.argument_dict)
                time.sleep(1)

            if not has_terminate_tool:
                user_content_list = []
                for tool_result, tool_call in zip(self.join_task(), assistant_message.tool_calls):
                    logger.info(f"submit step={i} tool_calls.name={tool_call.name} tool_result={tool_result}")
                    assert isinstance(tool_result, str)
                    user_content_list.append(f"<tool_response>\n{tool_result}\n</tool_response>")
                user_content_list.append(self.prompt_format(prompt_name="next_prompt"))
                assistant_message.tool_calls.clear()
                messages.append(Message(role=Role.USER, content="\n".join(user_content_list)))

            else:
                assistant_message.tool_calls.clear()
                messages.append(Message(role=Role.USER, content=self.prompt_format(prompt_name="final_prompt")))

        response.messages = messages
        response.answer = response.messages[-1].content
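For orientation, a standalone sketch of the user turn this loop appends after running tools: each tool result is wrapped in tool_response tags and the next_prompt text is appended. The tool outputs below are made up; in the op they come from self.join_task() and the prompt from react_v1_prompt.yaml:

```python
# Hypothetical tool outputs for two tool calls.
tool_results = ["4", "It is sunny in Hangzhou today."]
next_prompt = "Think based on the current content and the user's question: ..."  # abbreviated

user_content_list = [f"<tool_response>\n{result}\n</tool_response>" for result in tool_results]
user_content_list.append(next_prompt)
next_user_turn = "\n".join(user_content_list)
print(next_user_turn)
```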
llmflow/op/react/react_v1_prompt.yaml
ADDED
@@ -0,0 +1,28 @@
role_prompt: |
  You are a helpful assistant.
  The current time is {time}.
  Please proactively choose the most suitable tool or combination of tools based on the user's question, including {tools} etc.
  Please first think about how to break down the problem into subtasks, what tools and parameters should be used for each subtask, and finally provide the tool call name and parameters.
  Try calling the same tool multiple times with different parameters to obtain information from various perspectives.
  Please determine the response language based on the language of the user's question.

  {query}

# write a complete and rigorous report to answer user's questions based on the context.
next_prompt: |
  Think based on the current content and the user's question: Is the current context sufficient to answer the user's question?

  - If the current context is not sufficient to answer the user's question, consider what information is missing.
    Re-plan and think about how to break down the missing information into subtasks.
    For each subtask, determine what tools and parameters should be used for the query.
    Please first provide the reasoning process, then give the tool call name and parameters.

  - If the current context is sufficient to answer the user's question, use the **terminate** tool.

# Please determine the response language based on the language of the user's question.
final_prompt: |
  Please integrate the context and provide a complete answer to the user's question.

  # User's Question
  {query}
llmflow/op/vector_store/__init__.py
ADDED
@@ -0,0 +1,13 @@
"""
1. retrieve:
    search: query(context), workspace_id(request), top_k(request)
2. summary:
    insert: nodes(context), workspace_id(request)
    delete: ids(context), workspace_id(request)
    search: query(context), workspace_id(request), top_k(request.config.op)
3. vector:
    dump: workspace_id(request), path(str), max_size(int)
    load: workspace_id(request), path(str)
    delete: workspace_id(request)
    copy: source_id, target_id, max_size(int)
"""
llmflow/op/vector_store/recall_vector_store_op.py
ADDED
@@ -0,0 +1,48 @@
from typing import List

from loguru import logger

from llmflow.op import OP_REGISTRY
from llmflow.op.base_op import BaseOp
from llmflow.schema.experience import BaseExperience, vector_node_to_experience
from llmflow.schema.request import RetrieverRequest
from llmflow.schema.response import RetrieverResponse
from llmflow.schema.vector_node import VectorNode


@OP_REGISTRY.register()
class RecallVectorStoreOp(BaseOp):
    SEARCH_QUERY = "search_query"
    SEARCH_MESSAGE = "search_message"

    def execute(self):
        # get query
        query = self.context.get_context(self.SEARCH_QUERY)
        assert query, "query should be not empty!"

        # retrieve from vector store
        request: RetrieverRequest = self.context.request
        nodes: List[VectorNode] = self.vector_store.search(query=query,
                                                           workspace_id=request.workspace_id,
                                                           top_k=request.top_k)

        # convert to experience, filter duplicate
        experience_list: List[BaseExperience] = []
        experience_content_list: List[str] = []
        for node in nodes:
            experience: BaseExperience = vector_node_to_experience(node)
            if experience.content not in experience_content_list:
                experience_list.append(experience)
                experience_content_list.append(experience.content)
        experience_size = len(experience_list)
        logger.info(f"retrieve experience size={experience_size}")

        # filter by score
        threshold_score: float | None = self.op_params.get("threshold_score", None)
        if threshold_score is not None:
            experience_list = [e for e in experience_list if e.score >= threshold_score or e.score is None]
            logger.info(f"after filter by threshold_score size={len(experience_list)}")

        # set response
        response: RetrieverResponse = self.context.response
        response.experience_list = experience_list
llmflow/op/vector_store/update_vector_store_op.py
ADDED
@@ -0,0 +1,28 @@
import json
from typing import List

from loguru import logger

from llmflow.op import OP_REGISTRY
from llmflow.op.base_op import BaseOp
from llmflow.schema.experience import BaseExperience
from llmflow.schema.request import BaseRequest
from llmflow.schema.vector_node import VectorNode


@OP_REGISTRY.register()
class UpdateVectorStoreOp(BaseOp):

    def execute(self):
        request: BaseRequest = self.context.request

        experience_ids: List[str] | None = self.context.response.deleted_experience_ids
        if experience_ids:
            self.vector_store.delete(node_ids=experience_ids, workspace_id=request.workspace_id)
            logger.info(f"delete experience_ids={json.dumps(experience_ids, indent=2)}")

        insert_experience_list: List[BaseExperience] | None = self.context.response.experience_list
        if insert_experience_list:
            insert_nodes: List[VectorNode] = [x.to_vector_node() for x in insert_experience_list]
            self.vector_store.insert(nodes=insert_nodes, workspace_id=request.workspace_id)
            logger.info(f"insert insert_node.size={len(insert_nodes)}")
llmflow/op/vector_store/vector_store_action_op.py
ADDED
@@ -0,0 +1,46 @@
from llmflow.op import OP_REGISTRY
from llmflow.op.base_op import BaseOp
from llmflow.schema.experience import vector_node_to_experience, dict_to_experience, BaseExperience
from llmflow.schema.request import VectorStoreRequest
from llmflow.schema.response import VectorStoreResponse
from llmflow.schema.vector_node import VectorNode


@OP_REGISTRY.register()
class VectorStoreActionOp(BaseOp):

    def execute(self):
        request: VectorStoreRequest = self.context.request
        response: VectorStoreResponse = self.context.response

        if request.action == "copy":
            result = self.vector_store.copy_workspace(src_workspace_id=request.src_workspace_id,
                                                      dest_workspace_id=request.workspace_id)

        elif request.action == "delete":
            result = self.vector_store.delete_workspace(workspace_id=request.workspace_id)

        elif request.action == "dump":
            def node_to_experience(node: VectorNode) -> dict:
                return vector_node_to_experience(node).model_dump()

            result = self.vector_store.dump_workspace(workspace_id=request.workspace_id,
                                                      path=request.path,
                                                      callback_fn=node_to_experience)

        elif request.action == "load":
            def experience_dict_to_node(experience_dict: dict) -> VectorNode:
                experience: BaseExperience = dict_to_experience(experience_dict=experience_dict)
                return experience.to_vector_node()

            result = self.vector_store.load_workspace(workspace_id=request.workspace_id,
                                                      path=request.path,
                                                      callback_fn=experience_dict_to_node)

        else:
            raise ValueError(f"invalid action={request.action}")

        if isinstance(result, dict):
            response.metadata.update(result)
        else:
            response.metadata["result"] = str(result)
llmflow/pipeline/__init__.py
File without changes
llmflow/pipeline/pipeline.py
ADDED
@@ -0,0 +1,94 @@
from concurrent.futures import as_completed
from itertools import zip_longest
from typing import List

from loguru import logger

from llmflow.op import OP_REGISTRY
from llmflow.op.base_op import BaseOp
from llmflow.pipeline.pipeline_context import PipelineContext
from llmflow.utils.timer import Timer, timer


class Pipeline:
    seq_symbol: str = "->"
    parallel_symbol: str = "|"

    def __init__(self, pipeline: str, context: PipelineContext):
        self.pipeline_list: List[str | List[str]] = self._parse_pipline(pipeline)
        self.context: PipelineContext = context

    def _parse_pipline(self, pipeline: str) -> List[str | List[str]]:
        pipeline_list: List[str | List[str]] = []

        for pipeline_split1 in pipeline.split("["):
            for sub_pipeline in pipeline_split1.split("]"):
                sub_pipeline = sub_pipeline.strip().strip(self.seq_symbol)
                if not sub_pipeline:
                    continue

                if self.parallel_symbol in sub_pipeline:
                    pipeline_list.append(sub_pipeline.split(self.parallel_symbol))
                else:
                    pipeline_list.append(sub_pipeline)
                logger.info(f"add sub_pipeline={sub_pipeline}")
        return pipeline_list

    def _execute_sub_pipeline(self, pipeline: str):
        op_config_dict = self.context.app_config.op
        for op in pipeline.split(self.seq_symbol):
            op = op.strip()
            if not op:
                continue

            assert op in op_config_dict, f"op={op} config is missing!"
            op_config = op_config_dict[op]

            assert op_config.backend in OP_REGISTRY, f"op={op} backend={op_config.backend} is not registered!"
            op_cls = OP_REGISTRY[op_config.backend]

            op_obj: BaseOp = op_cls(context=self.context, op_config=op_config)
            op_obj.execute_wrap()

    def _parse_sub_pipeline(self, pipeline: str):
        for op in pipeline.split(self.seq_symbol):
            op = op.strip()
            if not op:
                continue

            yield op

    def print_pipeline(self):
        i: int = 0
        for pipeline in self.pipeline_list:
            if isinstance(pipeline, str):
                for op in self._parse_sub_pipeline(pipeline):
                    i += 1
                    logger.info(f"stage_{i}: {op}")

            elif isinstance(pipeline, list):
                parallel_pipeline = [self._parse_sub_pipeline(x) for x in pipeline]
                for op_list in zip_longest(*parallel_pipeline, fillvalue="-"):
                    i += 1
                    logger.info(f"stage{i}: {' | '.join(op_list)}")
            else:
                raise ValueError(f"unknown pipeline.type={type(pipeline)}")

    @timer(name="pipeline.execute")
    def __call__(self, enable_print: bool = True):
        if enable_print:
            self.print_pipeline()

        for i, pipeline in enumerate(self.pipeline_list):
            with Timer(f"step_{i}"):
                if isinstance(pipeline, str):
                    self._execute_sub_pipeline(pipeline)

                else:
                    future_list = []
                    for sub_pipeline in pipeline:
                        future = self.context.thread_pool.submit(self._execute_sub_pipeline, pipeline=sub_pipeline)
                        future_list.append(future)

                    for future in as_completed(future_list):
                        future.result()
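A small sketch of the pipeline expression syntax parsed above: ops separated by "->" run sequentially, and a bracketed group using "|" fans out to the thread pool. The op names are hypothetical, and per-op whitespace is stripped again at execution time:

```python
from llmflow.pipeline.pipeline import Pipeline
from llmflow.pipeline.pipeline_context import PipelineContext

# Minimal context for parsing only; a real run needs app_config, thread_pool, request, response.
context = PipelineContext(app_config=None, thread_pool=None)
pipeline = Pipeline("recall_op -> [mock1_op|mock2_op] -> update_op", context=context)

# pipeline.pipeline_list is roughly:
#   ['recall_op', ['mock1_op', 'mock2_op'], 'update_op']
# With a fully populated context, calling pipeline() executes stage by stage and runs
# the bracketed ops in parallel on the context's thread pool.
```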
llmflow/pipeline/pipeline_context.py
ADDED
@@ -0,0 +1,37 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

from llmflow.schema.app_config import AppConfig
from llmflow.vector_store.base_vector_store import BaseVectorStore


class PipelineContext:

    def __init__(self, **kwargs):
        self._context: dict = {**kwargs}

    def get_context(self, key: str, default=None):
        return self._context.get(key, default)

    def set_context(self, key: str, value):
        self._context[key] = value

    @property
    def request(self):
        return self._context["request"]

    @property
    def response(self):
        return self._context["response"]

    @property
    def app_config(self) -> AppConfig:
        return self._context["app_config"]

    @property
    def thread_pool(self) -> ThreadPoolExecutor:
        return self._context["thread_pool"]

    @property
    def vector_store_dict(self) -> Dict[str, BaseVectorStore]:
        return self._context["vector_store_dict"]
llmflow/schema/__init__.py
File without changes
llmflow/schema/app_config.py
ADDED
@@ -0,0 +1,69 @@
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class HttpServiceConfig:
    host: str = field(default="0.0.0.0")
    port: int = field(default=8001)
    timeout_keep_alive: int = field(default=600)
    limit_concurrency: int = field(default=64)


@dataclass
class ThreadPoolConfig:
    max_workers: int = field(default=10)


@dataclass
class APIConfig:
    retriever: str = field(default="")
    summarizer: str = field(default="")
    vector_store: str = field(default="")
    agent: str = field(default="")


@dataclass
class OpConfig:
    backend: str = field(default="")
    prompt_file_path: str = field(default="")
    prompt_dict: dict = field(default_factory=dict)
    llm: str = field(default="")
    embedding_model: str = field(default="")
    vector_store: str = field(default="")
    params: dict = field(default_factory=dict)


@dataclass
class LLMConfig:
    backend: str = field(default="")
    model_name: str = field(default="")
    params: dict = field(default_factory=dict)


@dataclass
class EmbeddingModelConfig:
    backend: str = field(default="")
    model_name: str = field(default="")
    params: dict = field(default_factory=dict)


@dataclass
class VectorStoreConfig:
    backend: str = field(default="")
    embedding_model: str = field(default="")
    params: dict = field(default_factory=dict)


@dataclass
class AppConfig:
    pre_defined_config: str = field(default="mock_config")
    config_path: str = field(default="")
    mcp_transport: str = field(default="sse")
    http_service: HttpServiceConfig = field(default_factory=HttpServiceConfig)
    thread_pool: ThreadPoolConfig = field(default_factory=ThreadPoolConfig)
    api: APIConfig = field(default_factory=APIConfig)
    op: Dict[str, OpConfig] = field(default_factory=dict)
    llm: Dict[str, LLMConfig] = field(default_factory=dict)
    embedding_model: Dict[str, EmbeddingModelConfig] = field(default_factory=dict)
    vector_store: Dict[str, VectorStoreConfig] = field(default_factory=dict)