alayaflow 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/autotable/1.0.0/metadata.json +9 -0
  2. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/autotable/1.0.0/requirements.txt +11 -0
  3. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/autotable/1.0.0/workflow.py +400 -0
  4. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/simple_chat/1.0.0/metadata.json +9 -0
  5. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/simple_chat/1.0.0/metadata.py +16 -0
  6. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/simple_chat/1.0.0/requirements.txt +11 -0
  7. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/simple_chat/1.0.0/schemas.py +32 -0
  8. alayaflow-0.1.2/.alaya.ai/alayaflow/workflows/simple_chat/1.0.0/workflow.py +94 -0
  9. alayaflow-0.1.2/.github/workflows/pr-test.yml +41 -0
  10. alayaflow-0.1.2/.gitignore +240 -0
  11. alayaflow-0.1.2/LICENSE +661 -0
  12. alayaflow-0.1.2/PKG-INFO +99 -0
  13. alayaflow-0.1.2/README.md +79 -0
  14. alayaflow-0.1.2/examples/autotable_demo.py +60 -0
  15. alayaflow-0.1.2/examples/chat_demo.py +87 -0
  16. alayaflow-0.1.2/pyproject.origin.toml +47 -0
  17. alayaflow-0.1.2/pyproject.toml +47 -0
  18. alayaflow-0.1.2/src/alayaflow/__init__.py +5 -0
  19. alayaflow-0.1.2/src/alayaflow/api/__init__.py +8 -0
  20. alayaflow-0.1.2/src/alayaflow/api/api_singleton.py +99 -0
  21. alayaflow-0.1.2/src/alayaflow/clients/alayamem/base_client.py +19 -0
  22. alayaflow-0.1.2/src/alayaflow/clients/alayamem/http_client.py +64 -0
  23. alayaflow-0.1.2/src/alayaflow/common/config.py +106 -0
  24. alayaflow-0.1.2/src/alayaflow/component/__init__.py +0 -0
  25. alayaflow-0.1.2/src/alayaflow/component/chat_model.py +19 -0
  26. alayaflow-0.1.2/src/alayaflow/component/intent_classifier.py +94 -0
  27. alayaflow-0.1.2/src/alayaflow/component/langflow/__init__.py +0 -0
  28. alayaflow-0.1.2/src/alayaflow/component/langflow/intent_classifier.py +83 -0
  29. alayaflow-0.1.2/src/alayaflow/component/llm_node.py +114 -0
  30. alayaflow-0.1.2/src/alayaflow/component/memory.py +50 -0
  31. alayaflow-0.1.2/src/alayaflow/component/model/__init__.py +8 -0
  32. alayaflow-0.1.2/src/alayaflow/component/model/model_manager.py +60 -0
  33. alayaflow-0.1.2/src/alayaflow/component/model/schemas.py +33 -0
  34. alayaflow-0.1.2/src/alayaflow/component/retrieve_node.py +11 -0
  35. alayaflow-0.1.2/src/alayaflow/component/search_node.py +147 -0
  36. alayaflow-0.1.2/src/alayaflow/component/web_search.py +126 -0
  37. alayaflow-0.1.2/src/alayaflow/execution/__init__.py +6 -0
  38. alayaflow-0.1.2/src/alayaflow/execution/env_manager.py +425 -0
  39. alayaflow-0.1.2/src/alayaflow/execution/executor_manager.py +59 -0
  40. alayaflow-0.1.2/src/alayaflow/execution/executors/__init__.py +9 -0
  41. alayaflow-0.1.2/src/alayaflow/execution/executors/base_executor.py +9 -0
  42. alayaflow-0.1.2/src/alayaflow/execution/executors/naive_executor.py +119 -0
  43. alayaflow-0.1.2/src/alayaflow/execution/executors/uv_executor.py +125 -0
  44. alayaflow-0.1.2/src/alayaflow/execution/executors/worker_executor.py +12 -0
  45. alayaflow-0.1.2/src/alayaflow/execution/langfuse_tracing.py +104 -0
  46. alayaflow-0.1.2/src/alayaflow/execution/workflow_runner.py +98 -0
  47. alayaflow-0.1.2/src/alayaflow/utils/singleton.py +14 -0
  48. alayaflow-0.1.2/src/alayaflow/workflow/__init__.py +10 -0
  49. alayaflow-0.1.2/src/alayaflow/workflow/runnable/__init__.py +7 -0
  50. alayaflow-0.1.2/src/alayaflow/workflow/runnable/base_runnable_workflow.py +21 -0
  51. alayaflow-0.1.2/src/alayaflow/workflow/runnable/state_graph_runnable_workflow.py +29 -0
  52. alayaflow-0.1.2/src/alayaflow/workflow/workflow_info.py +50 -0
  53. alayaflow-0.1.2/src/alayaflow/workflow/workflow_loader.py +172 -0
  54. alayaflow-0.1.2/src/alayaflow/workflow/workflow_manager.py +269 -0
  55. alayaflow-0.1.2/tests/__init__.py +1 -0
  56. alayaflow-0.1.2/tests/clients/__init__.py +1 -0
  57. alayaflow-0.1.2/tests/clients/conftest.py +9 -0
  58. alayaflow-0.1.2/tests/clients/test_alayamem.py +57 -0
  59. alayaflow-0.1.2/tests/component/test_intent_classifier.py +236 -0
  60. alayaflow-0.1.2/tests/component/test_llm_node.py +159 -0
  61. alayaflow-0.1.2/tests/execution/test_env_reuse.py +243 -0
  62. alayaflow-0.1.2/tests/workflow/__init__.py +0 -0
  63. alayaflow-0.1.2/tests/workflow/conftest.py +14 -0
  64. alayaflow-0.1.2/tests/workflow/test_workflow_loader.py +38 -0
  65. alayaflow-0.1.2/uv.lock +2728 -0
@@ -0,0 +1,9 @@
1
+ {
2
+ "id": "autotable",
3
+ "name": "RAG 并发信息抽取工作流",
4
+ "description": "基于 LangGraph Map-Reduce 架构的高性能抽取流程。集成信号量限流(Semaphore)、JSON 结构化校验、文档截断及错误兜底机制。",
5
+ "version": "1.0.0",
6
+ "tags": ["rag", "extraction", "langgraph", "json-mode"],
7
+ "entry_file": "workflow.py",
8
+ "entry_point": "create_graph"
9
+ }
@@ -0,0 +1,11 @@
1
+ # LangGraph 核心依赖
2
+ langgraph>=0.2.0
3
+
4
+ # LangChain Community (用于 ChatOpenAI)
5
+ langchain-community>=0.3.0
6
+
7
+ # OpenAI SDK (DeepSeek API 兼容 OpenAI 格式)
8
+ openai>=1.0.0
9
+
10
+ # Langfuse
11
+ langfuse>=3.0.0,<4.0.0
@@ -0,0 +1,400 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, TypedDict, Annotated, Union, TypeAlias, Tuple
4
+ from collections import defaultdict
5
+ from threading import Semaphore
6
+
7
+ from langgraph.graph import StateGraph, START, END
8
+ from langgraph.types import Send
9
+ from langchain_core.runnables import RunnableConfig
10
+
11
+ from alayaflow.component.llm_node import LLMComponent, ResponseFormat
12
+ from alayaflow.clients.alayamem.http_client import HttpAlayaMemClient
13
+ from alayaflow.component.retrieve_node import RetrieveComponent
14
+
15
+
16
+ FieldSpec: TypeAlias = Union[str, Dict[str, List["FieldSpec"]]] # 递归:dict -> list[FieldSpec]
17
+
18
+ def merge_dicts(a: Dict, b: Dict) -> Dict:
19
+ return {**a, **b}
20
+
21
+ def deep_merge(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
22
+ out = dict(a or {})
23
+ for k, v in (b or {}).items():
24
+ if k in out and isinstance(out[k], dict) and isinstance(v, dict):
25
+ out[k] = deep_merge(out[k], v)
26
+ else:
27
+ out[k] = v
28
+ return out
29
+
30
+
31
+ @dataclass(frozen=True)
32
+ class GroupTask:
33
+ path: Tuple[str, ...] # 父路径,如 ("个人信息","联系方式");根为 ()
34
+ keys: Tuple[str, ...] # 该路径下需要抽取的叶子字段名
35
+
36
+
37
+ class OverallState(TypedDict):
38
+ fields: List[FieldSpec] # 输入模板(递归)
39
+ tasks: List[GroupTask] # 规划出的任务列表
40
+
41
+ # 调试信息:每个任务的检索片段
42
+ context_by_task: Annotated[Dict[str, List[str]], merge_dicts]
43
+
44
+ # 最终值树:通过 deep_merge reducer 并发合并 patch
45
+ final_result: Annotated[Dict[str, Any], deep_merge]
46
+
47
+ errors: Annotated[Dict[str, str], merge_dicts]
48
+
49
+
50
+ class TaskState(TypedDict):
51
+ task: GroupTask
52
+
53
+
54
+
55
+ def _as_list(x: Any) -> List[Any]:
56
+ if x is None:
57
+ return []
58
+ if isinstance(x, list):
59
+ return x
60
+ return [x]
61
+
62
+ def flatten_leaf_tasks(specs: List[FieldSpec], base_path: Optional[List[str]] = None) -> List[Tuple[Tuple[str, ...], str]]:
63
+ """
64
+ 返回:[(path_tuple, leaf_key), ...]
65
+ """
66
+ base_path = base_path or []
67
+ out: List[Tuple[Tuple[str, ...], str]] = []
68
+
69
+ for item in specs or []:
70
+ if isinstance(item, str):
71
+ out.append((tuple(base_path), item))
72
+ continue
73
+
74
+ if isinstance(item, dict):
75
+ for parent, children in item.items():
76
+ for child in _as_list(children):
77
+ if isinstance(child, str):
78
+ out.append((tuple(base_path + [parent]), child))
79
+ elif isinstance(child, dict):
80
+ out.extend(flatten_leaf_tasks([child], base_path + [parent]))
81
+ else:
82
+ pass
83
+ continue
84
+
85
+ return out
86
+
87
def plan_node(state: OverallState, config: RunnableConfig):
    """Group leaf fields by their parent path and emit one GroupTask per path."""
    grouped: Dict[Tuple[str, ...], List[str]] = defaultdict(list)
    for path, key in flatten_leaf_tasks(state["fields"]):
        grouped[path].append(key)

    # dict.fromkeys dedupes while preserving first-seen order.
    tasks: List[GroupTask] = [
        GroupTask(path=path, keys=tuple(dict.fromkeys(keys)))
        for path, keys in grouped.items()
    ]

    # Optional: stable ordering — cosmetic only, concurrent merge is order-independent.
    tasks.sort(key=lambda t: (len(t.path), t.path))
    return {"tasks": tasks}
107
+
108
+
109
def map_tasks(state: OverallState):
    """Fan out: one Send to the extract_task node per planned task."""
    sends = []
    for task in state["tasks"]:
        sends.append(Send("extract_task", {"task": task}))
    return sends
111
+
112
+
113
+
114
def make_patch(path: Tuple[str, ...], kv: Dict[str, str]) -> Dict[str, Any]:
    """Nest *kv* under the dict keys named by *path*, innermost-out.

    Example: path=("a", "b"), kv={"k": "v"} -> {"a": {"b": {"k": "v"}}}.
    An empty path returns a shallow copy of *kv*.
    """
    patch: Dict[str, Any] = dict(kv)
    for segment in reversed(path):
        patch = {segment: patch}
    return patch
123
+
124
+
125
+
126
def build_system_prompt(keys: list[str]) -> str:
    """Build the (Chinese) system prompt for the strict-JSON field extractor.

    Instructs the model to output exactly the fields in *keys* as a legal
    JSON object of strings, with "" for unknown/placeholder values, and
    gives table-cell and multi-entry formatting rules.
    """
    keys_str = ", ".join(keys)

    return f"""
你是一个严谨的“局部字段抽取器”(table patch extractor)。

你的任务是:**只为指定字段抽取值**,严格依据提供的知识片段,不得猜测或编造。

通用规则:
1. 输出必须是严格合法 JSON,不允许包含解释、Markdown、代码块或多余文本。
2. **只允许输出以下字段(不多不少)**:{keys_str}
3. 所有字段值必须是字符串。
4. 找不到 / 不确定 / 空值 / 占位符 → 必须输出空字符串 ""。
5. 字段名可能存在空格或轻微变体(如“姓 名”≈“姓名”),允许智能匹配,但不得扩展到未指定字段。

长文本字段格式规则(必须遵守):
- 当字段内容包含**多个条目、多个时间段或多段经历**时:
- 必须使用序号列表格式。
- **每个条目占一行,条目之间必须使用 "\n" 换行符分隔。**
- 不允许使用分号、顿号、逗号等方式合并多个条目到同一行。
- 示例正确格式:
"1.第一条内容\n2.第二条内容\n3.第三条内容"

表格单元格理解规则(重要):
- 知识片段可能来自表格,每行使用 " | " 分隔单元格。
- "<空>" 表示空单元格,对应值为 ""。
- 字段名后不一定是值:
- 若字段名后是 "<空>" → 值为 ""。
- 若字段名后是另一个字段名 → 继续向后寻找第一个“非字段名 / 非占位符”的单元格作为值。
- 示例:"字段A | 字段B | 值" → 字段A="", 字段B="值"。

占位符识别:
- 若候选值是模板占位符或签字日期类文本
(如“签字: 年 月 日”“学院盖章: 年 月 日”等),必须返回 ""。
""".strip()
161
+
162
+
163
def build_user_prompt(
    content_text: str,
    path: list[str],
    keys: list[str],
) -> str:
    """Build the (Chinese) user prompt for one extraction task.

    Embeds the field path (for semantic grounding only), the retrieved
    snippets, and an all-empty JSON skeleton listing exactly *keys*.
    """
    path_str = " / ".join(path) if path else "<root>"
    keys_str = ", ".join(keys)

    # Skeleton showing the exact JSON shape the model must fill in.
    json_skeleton = "{\n" + ",\n".join([f' "{k}": ""' for k in keys]) + "\n}"

    return f"""
【本次任务定位】
字段路径(仅用于语义定位,不要输出):{path_str}
需要抽取的字段:{keys_str}

【知识库片段】
{content_text}

【输出要求】
- 只输出一个 JSON 对象
- key 必须严格为:{keys_str}
- 无法确定 / 空值 / 占位符 → 输出 ""

【JSON 输出模板】
{json_skeleton}
""".strip()
189
+
190
+
191
def create_extract_task_node(
    client: HttpAlayaMemClient,
    *,
    max_concurrency: int = 10,
    top_k: int = 5,
    max_doc_chars: int = 400,
):
    """Build the fan-out extraction node (retrieve -> LLM -> patch).

    Args:
        client: AlayaMem HTTP client used for retrieval.
        max_concurrency: max simultaneous retrieve+LLM calls, gated by a Semaphore.
        top_k: number of snippets to retrieve per task.
        max_doc_chars: per-snippet truncation limit before prompting.

    Returns:
        A LangGraph node function operating on TaskState. The node never raises:
        on any failure it emits an all-empty patch plus an `errors` entry.
    """
    limiter = Semaphore(max_concurrency)

    def slim_docs(docs: List[str]) -> List[str]:
        # Truncate each snippet so the prompt stays bounded.
        out = []
        for d in docs or []:
            s = str(d)
            if len(s) > max_doc_chars:
                s = s[:max_doc_chars] + "…"
            out.append(s)
        return out

    def node(state: TaskState, config: RunnableConfig):
        task = state["task"]
        path = task.path
        keys = list(task.keys)

        task_id = f"{'/'.join(path) or '<root>'}:{','.join(keys)}"

        # Default patch: keeps the output structure stable (missing fields become "").
        default_kv = {k: "" for k in keys}
        default_patch = make_patch(path, default_kv)

        try:
            with limiter:
                # Runtime parameter from config: which collection to search.
                config_dict = config.get("configurable", {}) if isinstance(config, dict) else {}
                collection_name = config_dict.get("collection_name", "file_watcher_collection")

                # 1) Retrieval query: path segments + keys.
                #    Deeper paths carry their parent headings to improve recall.
                query_parts = list(path) + keys
                query = ";".join([p for p in query_parts if p])

                retrieve_component = RetrieveComponent(client=client)
                docs = retrieve_component(query=query, collection_name=collection_name, limit=top_k)
                docs = slim_docs(docs)

                # No snippets retrieved: return the empty default patch directly.
                if not docs:
                    return {
                        "context_by_task": {task_id: []},
                        "final_result": default_patch,
                    }

                formatted_context = "\n\n".join(
                    [f"片段 {i+1}: {doc}" for i, doc in enumerate(docs)]
                )

                # 2) Extract all keys in a single strict-JSON LLM call.
                #    (Removed a dead local: a json_skeleton was computed here but
                #    never used — build_user_prompt builds its own.)
                system_prompt = build_system_prompt(keys)
                user_prompt = build_user_prompt(formatted_context, path, keys)

                llm = LLMComponent(
                    model_name="deepseek-chat",
                    system_prompt=system_prompt,
                    prompt=user_prompt,
                    response_format=ResponseFormat.JSON,
                    temperature=0.0,
                )

                msg = llm()
                obj = json.loads(msg.content)

                # Coerce every requested key to a stripped string; missing/None -> "".
                extracted = {}
                for k in keys:
                    v = obj.get(k, "")
                    extracted[k] = (str(v).strip() if v is not None else "")

                patch = make_patch(path, extracted)

                return {
                    "context_by_task": {task_id: docs},
                    "final_result": patch,
                }

        except Exception as e:
            # Best-effort fallback: never crash the graph; record the error
            # and emit empty values so the merged result keeps its shape.
            return {
                "context_by_task": {task_id: []},
                "final_result": default_patch,
                "errors": {task_id: f"{type(e).__name__}: {e}"},
            }

    return node
284
+
285
+
286
+
287
def validate_node(state: OverallState, config: RunnableConfig):
    """Check that every planned field got a non-empty value; report missing paths."""
    result = state.get("final_result", {}) or {}

    def _subtree(tree: Dict[str, Any], path: Tuple[str, ...]) -> Dict[str, Any]:
        # Descend along *path*; any non-dict on the way yields {}.
        current: Any = tree
        for part in path:
            if not isinstance(current, dict):
                return {}
            current = current.get(part, {})
        return current if isinstance(current, dict) else {}

    missing: List[str] = []
    for task in state["tasks"]:
        scope = _subtree(result, task.path)
        for key in task.keys:
            if not str(scope.get(key, "")).strip():
                missing.append(".".join(task.path + (key,)) if task.path else key)

    if not missing:
        return {}
    return {"errors": {"__missing__": ";".join(missing)}}
309
+
310
+
311
+ # -------------------------
312
+ # Build graph
313
+ # -------------------------
314
def create_graph(init_args: Dict[str, Any]):
    """Build and compile the map-reduce extraction graph.

    *init_args* must contain "alayamem_url" pointing at the AlayaMem service.
    """
    mem_client = HttpAlayaMemClient(init_args["alayamem_url"])

    builder = StateGraph(OverallState)
    builder.add_node("plan", plan_node)
    builder.add_node(
        "extract_task",
        create_extract_task_node(mem_client, max_concurrency=10, top_k=3),
    )
    builder.add_node("validate", validate_node)

    builder.add_edge(START, "plan")
    # plan fans out one Send per task to extract_task.
    builder.add_conditional_edges("plan", map_tasks, ["extract_task"])
    builder.add_edge("extract_task", "validate")
    builder.add_edge("validate", END)

    return builder.compile()
328
+
329
+
330
if __name__ == "__main__":
    # Smoke-test entry point: build the graph against a dev AlayaMem instance.
    app = create_graph({"alayamem_url": "http://10.16.70.46:5555"})

    # Recursive field template: plain strings are leaf fields at the current
    # level; a dict introduces a nested section listing its child fields.
    input_data: OverallState = {
        "fields": [
            {
                "申请人信息": [
                    "姓名",
                    "性别",
                    "出生年月",
                    "民族",
                    "学位",
                    "职称",
                    "是否在站博士后",
                    "电子邮箱",
                    "办公电话",
                    "国别或地区",
                    "申请人类别",
                    "工作单位",
                    "主要研究领域"
                ]
            },
            {
                "依托单位信息": [
                    "名称",
                    "联系人",
                    "电子邮箱",
                    "电话",
                    "网站地址"
                ]
            },
            {
                "合作研究单位信息": [
                    "单位名称"
                ]
            },
            {
                "项目基本信息": [
                    "项目名称",
                    "英文名称",
                    "资助类别",
                    "亚类说明",
                    "附注说明",
                    "申请代码",
                    "研究期限",
                    "研究方向",
                    "申请资助经费",
                    "研究属性",
                    "中文关键词",
                    "英文关键词"
                ]
            },
            "中文摘要",
            "英文摘要"
        ],
        # Remaining state keys start empty; reducers fill them during the run.
        "tasks": [],
        "context_by_task": {},
        "final_result": {},
        "errors": {},
    }

    # Runtime configuration read by the extraction node (see create_extract_task_node).
    config = {
        "configurable": {
            "collection_name": "file_watcher_collection",
        }
    }
    out = app.invoke(input_data, config=config)
    print("final_result:")
    print(json.dumps(out["final_result"], ensure_ascii=False, indent=2))
    print("\nerrors:")
    print(json.dumps(out["errors"], ensure_ascii=False, indent=2))
@@ -0,0 +1,9 @@
1
+ {
2
+ "id": "simple_chat",
3
+ "name": "Simple Chatbot",
4
+ "description": "一个简单的 LLM 对话工作流示例",
5
+ "version": "1.0.0",
6
+ "tags": ["chat", "basic"],
7
+ "entry_file": "workflow.py",
8
+ "entry_point": "create_graph"
9
+ }
@@ -0,0 +1,16 @@
1
+ from pathlib import Path
2
+
3
+ from alayaflow.workflow import WorkflowInfo
4
+
5
def get_metadata():
    """Return the WorkflowInfo descriptor for the simple_chat workflow."""
    return WorkflowInfo(
        id="simple_chat",
        name="Simple Chatbot",
        description="一个简单的 LLM 对话工作流示例",
        version="1.0.0",
        tags=["chat", "basic"],
        entry_file="workflow.py",
        entry_point="create_graph",
        # Workflow directory resolved relative to this metadata module.
        wf_dir=Path(__file__).parent,
    )
@@ -0,0 +1,11 @@
1
+ # LangGraph 核心依赖
2
+ langgraph>=0.2.0
3
+
4
+ # LangChain Community (用于 ChatOpenAI)
5
+ langchain-community>=0.3.0
6
+
7
+ # OpenAI SDK (DeepSeek API 兼容 OpenAI 格式)
8
+ openai>=1.0.0
9
+
10
+ # Langfuse
11
+ langfuse>=3.0.0,<4.0.0
@@ -0,0 +1,32 @@
1
+ from typing import TypedDict, List, Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from langchain_core.messages import BaseMessage, AIMessageChunk
6
+
7
+
8
class WorkflowInitArgs(BaseModel):
    """Arguments required to construct the simple_chat graph."""

    alayamem_url: str = Field(..., description="AlayaMem URL")


class Input(BaseModel):
    """Graph input schema."""

    messages: List[BaseMessage] = Field(..., description="List of input messages")


class WorkflowContext(BaseModel):
    """Per-invocation runtime context (delivered via the LangGraph Runtime)."""

    user_id: str = Field(..., description="User ID")
    session_id: str = Field(..., description="Session ID")
    chat_model_id: str = Field(..., description="Chat Model ID")


class Output(BaseModel):
    """Graph output schema."""

    chat_response: dict = Field(..., description="Chat response")
24
+
25
+
26
class WorkflowState(TypedDict):
    """Mutable graph state for the simple_chat workflow.

    NOTE: TypedDict fields cannot carry default values — the previous
    ``= False`` / ``= []`` initializers were never applied at runtime and
    are rejected by type checkers (the ``[]`` was also a shared mutable
    class attribute). Defaults must be supplied when the initial state
    dict is constructed.
    """

    messages: List[BaseMessage]
    # True once init_memory reported status == "success".
    memory_initialized: bool
    retrieved_docs: Optional[List[str]]
    # Streaming chunks accumulated during generation.
    stream_chunks: List[AIMessageChunk]
    chat_response: Optional[dict]
    context: Optional[str]
@@ -0,0 +1,94 @@
1
+ from langgraph.graph import StateGraph, START, END
2
+ from langgraph.runtime import Runtime
3
+
4
+ from alayaflow.component.memory import init_memory, query_message, add_message, query_vdb_message
5
+ from alayaflow.component.model import ModelManager
6
+
7
+ from .schemas import WorkflowInitArgs, WorkflowState, WorkflowContext, Input, Output
8
+
9
def mk_init_memory_node(alayamem_url: str):
    """Factory: node that initializes AlayaMem for the current user/session."""

    def init_memory_node(state: WorkflowState, runtime: Runtime[WorkflowContext]):
        ctx = runtime.context
        result = init_memory(alayamem_url, ctx.user_id, ctx.session_id)
        new_state = state.copy()
        # Memory counts as initialized only when the service reports success.
        new_state["memory_initialized"] = result.get("status", "") == "success"
        return new_state

    return init_memory_node
18
+
19
+ # Keep for integration
20
+ # def mk_query_message_node(alayamem_url: str):
21
+ # def query_message_node(state: WorkflowState):
22
+ # user_id = state["user_id"]
23
+ # session_id = state["session_id"]
24
+ # messages = state.get("messages", [])
25
+ # original_result = query_message(alayamem_url, user_id, session_id, messages)
26
+ # updated_state = state.copy()
27
+ # updated_state["context"] = original_result.get("context", "")
28
+ # return updated_state
29
+ # return query_message_node
30
+
31
def mk_query_vdb_message_node(alayamem_url: str):
    """Factory: node that retrieves related documents from the vector DB."""

    def query_vdb_message_node(state: WorkflowState):
        msgs = state.get("messages", [])
        # NOTE(review): "limit" is read from state but is not declared in
        # WorkflowState — presumably it always falls back to 5; confirm.
        top_k = state.get("limit", 5)
        result = query_vdb_message(alayamem_url, msgs, top_k)
        new_state = state.copy()
        new_state["retrieved_docs"] = result.get("vdb_results", [])
        return new_state

    return query_vdb_message_node
40
+
41
def mk_chat_node():
    """Factory: node that calls the configured chat model, with optional RAG context."""
    model_manager = ModelManager()

    def chat_node(state: WorkflowState, runtime: Runtime[WorkflowContext]):
        model_id = runtime.context.chat_model_id
        chat_model = model_manager.get_model(model_id)
        if not chat_model:
            raise ValueError(f"无法找到模型ID为 '{model_id}' 的模型配置")

        history = state["messages"].copy()
        new_state = state.copy()

        docs = state.get("retrieved_docs", [])
        if docs:
            # Prepend retrieved snippets as a system message so the model
            # grounds its answer in them.
            context_text = "\n\n".join([str(doc) for doc in docs])
            from langchain_core.messages import SystemMessage
            history.insert(
                0,
                SystemMessage(
                    content=f"以下是相关的参考资料,请基于这些资料回答用户的问题:\n\n{context_text}"
                ),
            )

        new_state['chat_response'] = chat_model.invoke(history)
        return new_state

    return chat_node
66
+
67
def mk_add_message_node(alayamem_url: str):
    """Factory: node that persists the conversation messages to AlayaMem."""

    def add_message_node(state: WorkflowState, runtime: Runtime[WorkflowContext]):
        ctx = runtime.context
        add_message(alayamem_url, ctx.user_id, ctx.session_id, state.get("messages", []))
        return state.copy()

    return add_message_node
75
+
76
def create_graph(init_args: WorkflowInitArgs | dict):
    """Build and compile the simple_chat graph: init memory -> retrieve -> chat -> persist."""
    if isinstance(init_args, dict):
        init_args = WorkflowInitArgs(**init_args)
    alayamem_url = init_args.alayamem_url

    builder = StateGraph(WorkflowState, WorkflowContext, input_type=Input, output_type=Output)

    builder.add_node("init_memory_node", mk_init_memory_node(alayamem_url))
    builder.add_node("query_vdb_message_node", mk_query_vdb_message_node(alayamem_url))
    builder.add_node("chat_node", mk_chat_node())
    builder.add_node("add_message_node", mk_add_message_node(alayamem_url))

    # Strictly linear pipeline.
    for src, dst in [
        (START, "init_memory_node"),
        ("init_memory_node", "query_vdb_message_node"),
        ("query_vdb_message_node", "chat_node"),
        ("chat_node", "add_message_node"),
        ("add_message_node", END),
    ]:
        builder.add_edge(src, dst)

    return builder.compile()
@@ -0,0 +1,41 @@
1
name: PR 自动测试
on:
  pull_request:
    branches: [ main, master ]
    # paths:
    #   - 'src/**'
    #   - 'tests/**'
    #   - 'pyproject.toml'
    #   - 'uv.lock'

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: 拉取代码
        uses: actions/checkout@v4

      - name: 配置Python版本
        # FIX: the cache key below references steps.setup-python.outputs.python-version,
        # which resolved to an empty string because this step had no id.
        id: setup-python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"
          allow-prereleases: false

      - name: 安装uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: 缓存uv依赖
        uses: actions/cache@v4
        with:
          path: ~/.cache/uv
          # Cache key: OS + Python version + uv.lock hash (lockfile change re-caches).
          key: ${{ runner.os }}-python-${{ steps.setup-python.outputs.python-version }}-uv-${{ hashFiles('uv.lock') }}
          # Fallback: newest cache for the same OS + Python version.
          restore-keys: |
            ${{ runner.os }}-python-${{ steps.setup-python.outputs.python-version }}-uv-

      - name: 安装项目依赖
        run: uv sync

      - name: 执行pytest测试
        run: uv run pytest