gohumanloop 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gohumanloop/adapters/__init__.py +17 -0
- gohumanloop/adapters/langgraph_adapter.py +819 -0
- gohumanloop/cli/__init__.py +0 -0
- gohumanloop/cli/main.py +29 -0
- gohumanloop/core/__init__.py +0 -0
- gohumanloop/core/interface.py +437 -0
- gohumanloop/core/manager.py +576 -0
- gohumanloop/manager/__init__.py +0 -0
- gohumanloop/manager/ghl_manager.py +532 -0
- gohumanloop/models/__init__.py +0 -0
- gohumanloop/models/api_model.py +54 -0
- gohumanloop/models/glh_model.py +23 -0
- gohumanloop/providers/__init__.py +0 -0
- gohumanloop/providers/api_provider.py +628 -0
- gohumanloop/providers/base.py +428 -0
- gohumanloop/providers/email_provider.py +1019 -0
- gohumanloop/providers/ghl_provider.py +64 -0
- gohumanloop/providers/terminal_provider.py +302 -0
- gohumanloop/utils/__init__.py +1 -0
- gohumanloop/utils/context_formatter.py +59 -0
- gohumanloop/utils/threadsafedict.py +243 -0
- gohumanloop/utils/utils.py +40 -0
- {gohumanloop-0.0.1.dist-info → gohumanloop-0.0.2.dist-info}/METADATA +2 -1
- gohumanloop-0.0.2.dist-info/RECORD +30 -0
- gohumanloop-0.0.1.dist-info/RECORD +0 -8
- {gohumanloop-0.0.1.dist-info → gohumanloop-0.0.2.dist-info}/WHEEL +0 -0
- {gohumanloop-0.0.1.dist-info → gohumanloop-0.0.2.dist-info}/entry_points.txt +0 -0
- {gohumanloop-0.0.1.dist-info → gohumanloop-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {gohumanloop-0.0.1.dist-info → gohumanloop-0.0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,64 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Dict, Any, Optional
|
3
|
+
from pydantic import BaseModel, Field, field_validator, SecretStr
|
4
|
+
|
5
|
+
from gohumanloop.models.glh_model import GoHumanLoopConfig
|
6
|
+
from gohumanloop.providers.api_provider import APIProvider
|
7
|
+
from gohumanloop.utils import get_secret_from_env
|
8
|
+
|
9
|
+
class GoHumanLoopProvider(APIProvider):
    """GoHumanLoop platform provider.

    Concrete implementation of `APIProvider` responsible for interacting with
    the GoHumanLoop platform. The API key and base URL are validated through
    the `GoHumanLoopConfig` pydantic model before being handed to
    `APIProvider`.
    """

    def __init__(
        self,
        name: str,
        request_timeout: int = 30,
        poll_interval: int = 5,
        max_retries: int = 3,
        default_platform: Optional[str] = "GoHumanLoop",
        config: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base_url: Optional[str] = None,
    ):
        """Initialize the GoHumanLoop provider.

        Args:
            name: Provider name.
            request_timeout: API request timeout in seconds.
            poll_interval: Polling interval in seconds.
            max_retries: Maximum number of retries.
            default_platform: Default platform passed on to `APIProvider`
                (defaults to "GoHumanLoop").
            config: Additional configuration parameters.
            api_key: GoHumanLoop API key. When None, it is read via
                `get_secret_from_env("GOHUMANLOOP_API_KEY")`.
            api_base_url: GoHumanLoop API base URL. When None, it is read from
                the GOHUMANLOOP_API_BASE_URL environment variable, falling
                back to "https://www.gohumanloop.com".
        """
        # Only consult the environment when the caller did not supply a key,
        # so an explicitly passed key works even if the variable is unset.
        if api_key is None:
            api_key = get_secret_from_env("GOHUMANLOOP_API_KEY")

        if api_base_url is None:
            api_base_url = os.environ.get(
                "GOHUMANLOOP_API_BASE_URL", "https://www.gohumanloop.com"
            )

        # Validate configuration using the pydantic model.
        ghl_config = GoHumanLoopConfig(
            api_key=api_key,
            api_base_url=api_base_url,
        )

        super().__init__(
            name=name,
            api_base_url=ghl_config.api_base_url,
            api_key=ghl_config.api_key,
            default_platform=default_platform,
            request_timeout=request_timeout,
            poll_interval=poll_interval,
            max_retries=max_retries,
            config=config,
        )

    def __str__(self) -> str:
        """Return a string description of this instance."""
        base_str = super().__str__()
        ghl_info = "- GoHumanLoop Provider: Connected to GoHumanLoop Official Platform\n"
        return f"{ghl_info}{base_str}"
|
64
|
+
|
@@ -0,0 +1,302 @@
|
|
1
|
+
import asyncio
|
2
|
+
from email import message
|
3
|
+
import sys
|
4
|
+
import json
|
5
|
+
from typing import Dict, Any, Optional, List
|
6
|
+
from datetime import datetime
|
7
|
+
|
8
|
+
from gohumanloop.core.interface import (HumanLoopResult, HumanLoopStatus, HumanLoopType)
|
9
|
+
from gohumanloop.providers.base import BaseProvider
|
10
|
+
|
11
|
+
class TerminalProvider(BaseProvider):
    """Terminal-based human-in-the-loop provider.

    Interacts with users through the command line; suitable for testing and
    simple scenarios. Users answer requests via terminal input, supporting
    approval, information-collection and conversation interactions.
    """

    def __init__(self, name: str, config: Optional[Dict[str, Any]] = None):
        """Initialize the terminal provider.

        Args:
            name: Provider name.
            config: Optional configuration, passed through to `BaseProvider`.
        """
        super().__init__(name, config)
        # Strong references to in-flight terminal interaction tasks. The event
        # loop only keeps weak references to tasks, so a task nobody holds on
        # to may be garbage-collected before it finishes.
        self._terminal_tasks = set()

    def __str__(self) -> str:
        """Return a string description of this instance."""
        base_str = super().__str__()
        terminal_info = "- Terminal Provider: Terminal-based human-in-the-loop implementation\n"
        return f"{terminal_info}{base_str}"

    def _start_terminal_interaction(self, conversation_id: str, request_id: str) -> None:
        """Schedule the terminal interaction task and keep it referenced until done."""
        task = asyncio.create_task(
            self._process_terminal_interaction(conversation_id, request_id)
        )
        self._terminal_tasks.add(task)
        task.add_done_callback(self._terminal_tasks.discard)

    async def request_humanloop(
        self,
        task_id: str,
        conversation_id: str,
        loop_type: HumanLoopType,
        context: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None
    ) -> HumanLoopResult:
        """Request a human-in-the-loop interaction through the terminal.

        Args:
            task_id: Task identifier.
            conversation_id: Conversation ID for multi-turn dialogs.
            loop_type: Loop type.
            context: Context information provided to the human.
            metadata: Additional metadata.
            timeout: Request timeout in seconds; a timeout task is created
                only when this is truthy.

        Returns:
            HumanLoopResult: Result object with the new request ID and PENDING
            status. The terminal interaction runs in a background task.
        """
        request_id = self._generate_request_id()

        self._store_request(
            conversation_id=conversation_id,
            request_id=request_id,
            task_id=task_id,
            loop_type=loop_type,
            context=context,
            metadata=metadata or {},
            timeout=timeout
        )

        result = HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=loop_type,
            status=HumanLoopStatus.PENDING
        )

        # Collect the user's input in the background.
        self._start_terminal_interaction(conversation_id, request_id)

        if timeout:
            self._create_timeout_task(conversation_id, request_id, timeout)

        return result

    async def check_request_status(
        self,
        conversation_id: str,
        request_id: str
    ) -> HumanLoopResult:
        """Check request status.

        Args:
            conversation_id: Conversation identifier.
            request_id: Request identifier.

        Returns:
            HumanLoopResult: Current status of the request, or an ERROR result
            if the request is unknown.
        """
        request_info = self._get_request(conversation_id, request_id)
        if not request_info:
            return HumanLoopResult(
                conversation_id=conversation_id,
                request_id=request_id,
                loop_type=HumanLoopType.CONVERSATION,
                status=HumanLoopStatus.ERROR,
                error=f"Request '{request_id}' not found in conversation '{conversation_id}'"
            )

        return HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=request_info.get("loop_type", HumanLoopType.CONVERSATION),
            status=request_info.get("status", HumanLoopStatus.PENDING),
            response=request_info.get("response", {}),
            feedback=request_info.get("feedback", {}),
            responded_by=request_info.get("responded_by", None),
            responded_at=request_info.get("responded_at", None),
            error=request_info.get("error", None)
        )

    async def continue_humanloop(
        self,
        conversation_id: str,
        context: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> HumanLoopResult:
        """Continue a human-in-the-loop interaction for multi-turn conversations.

        Args:
            conversation_id: Conversation identifier.
            context: Context information provided to the human.
            metadata: Additional metadata.
            timeout: Request timeout in seconds.

        Returns:
            HumanLoopResult: Result object with the new request ID and PENDING
            status, or an ERROR result if the conversation is unknown.
        """
        conversation_info = self._get_conversation(conversation_id)
        if not conversation_info:
            return HumanLoopResult(
                conversation_id=conversation_id,
                request_id="",
                loop_type=HumanLoopType.CONVERSATION,
                status=HumanLoopStatus.ERROR,
                error=f"Conversation '{conversation_id}' not found"
            )

        request_id = self._generate_request_id()
        task_id = conversation_info.get("task_id", "unknown_task")

        self._store_request(
            conversation_id=conversation_id,
            request_id=request_id,
            task_id=task_id,
            loop_type=HumanLoopType.CONVERSATION,  # continued dialogs default to conversation type
            context=context,
            metadata=metadata or {},
            timeout=timeout
        )

        result = HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=HumanLoopType.CONVERSATION,
            status=HumanLoopStatus.PENDING
        )

        self._start_terminal_interaction(conversation_id, request_id)

        if timeout:
            self._create_timeout_task(conversation_id, request_id, timeout)

        return result

    async def _process_terminal_interaction(self, conversation_id: str, request_id: str):
        """Print the prompt for a stored request and collect the user's answer.

        Dispatches to the handler matching the request's loop type. If stdin is
        closed (EOFError), the request is marked ERROR instead of being left
        PENDING by a silently failed background task.
        """
        request_info = self._get_request(conversation_id, request_id)
        if not request_info:
            return

        prompt = self.build_prompt(
            task_id=request_info["task_id"],
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=request_info["loop_type"],
            created_at=request_info.get("created_at", ""),
            context=request_info["context"],
            metadata=request_info.get("metadata")
        )

        loop_type = request_info["loop_type"]

        print(prompt)

        try:
            if loop_type == HumanLoopType.APPROVAL:
                await self._handle_approval_interaction(conversation_id, request_id, request_info)
            elif loop_type == HumanLoopType.INFORMATION:
                await self._handle_information_interaction(conversation_id, request_id, request_info)
            else:  # HumanLoopType.CONVERSATION
                await self._handle_conversation_interaction(conversation_id, request_id, request_info)
        except EOFError:
            request_info["status"] = HumanLoopStatus.ERROR
            request_info["error"] = "Terminal input closed before a response was received"

    async def _handle_approval_interaction(self, conversation_id: str, request_id: str, request_info: Dict[str, Any]):
        """Handle an approval interaction.

        Re-prompts until the user enters a recognized approve/reject answer.
        On rejection, a free-text reason is also collected.

        Args:
            conversation_id: Conversation ID.
            request_id: Request ID.
            request_info: Stored request information; updated in place.
        """
        loop = asyncio.get_running_loop()

        # Loop rather than recurse so repeated invalid input cannot grow the stack.
        while True:
            print("\nPlease enter your decision (approve/reject):")
            # input() blocks, so run it in the default thread pool.
            response = await loop.run_in_executor(None, input)
            response = response.strip().lower()

            if response in ["approve", "yes", "y", "同意", "批准"]:
                status = HumanLoopStatus.APPROVED
                response_data = ""
                break
            if response in ["reject", "no", "n", "拒绝", "不同意"]:
                status = HumanLoopStatus.REJECTED
                print("\nPlease enter the reason for rejection:")
                response_data = await loop.run_in_executor(None, input)
                break

            print("\nInvalid input, please enter 'approve' or 'reject'")

        request_info["status"] = status
        request_info["response"] = response_data
        request_info["responded_by"] = "terminal_user"
        request_info["responded_at"] = datetime.now().isoformat()

        print(f"\nYour decision has been recorded: {status.value}")

    async def _handle_information_interaction(self, conversation_id: str, request_id: str, request_info: Dict[str, Any]):
        """Handle an information-collection interaction.

        Args:
            conversation_id: Conversation ID.
            request_id: Request ID.
            request_info: Stored request information; updated in place.
        """
        print("\nPlease provide the required information:")

        # input() blocks, so run it in the default thread pool.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, input)

        request_info["status"] = HumanLoopStatus.COMPLETED
        request_info["response"] = response
        request_info["responded_by"] = "terminal_user"
        request_info["responded_at"] = datetime.now().isoformat()

        print("\nYour information has been recorded")

    async def _handle_conversation_interaction(self, conversation_id: str, request_id: str, request_info: Dict[str, Any]):
        """Handle a conversation interaction.

        An exit keyword marks the request COMPLETED; any other reply marks it
        INPROGRESS.

        Args:
            conversation_id: Conversation ID.
            request_id: Request ID.
            request_info: Stored request information; updated in place.
        """
        print("\nPlease enter your response (type 'exit' to end conversation):")

        # input() blocks, so run it in the default thread pool.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, input)

        if response.strip().lower() in ["exit", "quit", "结束", "退出"]:
            status = HumanLoopStatus.COMPLETED
            print("\nConversation ended")
        else:
            status = HumanLoopStatus.INPROGRESS

        request_info["status"] = status
        request_info["response"] = response
        request_info["responded_by"] = "terminal_user"
        request_info["responded_at"] = datetime.now().isoformat()

        print("\nYour response has been recorded")
|
@@ -0,0 +1 @@
|
|
1
|
+
from .utils import run_async_safely, get_secret_from_env
|
@@ -0,0 +1,59 @@
|
|
1
|
+
from typing import Dict, Any, List, Optional, Union
|
2
|
+
import json
|
3
|
+
|
4
|
+
class ContextFormatter:
    """Context formatting utilities for human-readable and API payloads."""

    @staticmethod
    def format_for_human(context: Dict[str, Any]) -> str:
        """Render a context mapping as Markdown-style text for a human reader.

        The keys title, description, task, agent, action and reason each get a
        dedicated section. All remaining keys are listed under an
        "additional information" section. Dict/list values there are
        pretty-printed as JSON; nested values JSON cannot encode are rendered
        with str() instead of raising.

        Args:
            context: Context mapping to render.

        Returns:
            The rendered text, with sections joined by newlines ("" for an
            empty context).
        """
        result = []

        # Title (if present)
        if "title" in context:
            result.append(f"# {context['title']}\n")

        # Description (if present)
        if "description" in context:
            result.append(f"{context['description']}\n")

        # Task information
        if "task" in context:
            result.append(f"## 任务\n{context['task']}\n")

        # Agent information
        if "agent" in context:
            result.append(f"## 代理\n{context['agent']}\n")

        # Requested action
        if "action" in context:
            result.append(f"## 请求的操作\n{context['action']}\n")

        # Reason
        if "reason" in context:
            result.append(f"## 原因\n{context['reason']}\n")

        # Remaining key/value pairs
        known_keys = ("title", "description", "task", "agent", "action", "reason")
        other_keys = [k for k in context if k not in known_keys]
        if other_keys:
            result.append("## 附加信息\n")
            for key in other_keys:
                value = context[key]
                if isinstance(value, (dict, list)):
                    # default=str keeps nested non-JSON values (datetimes,
                    # custom objects, ...) from raising TypeError.
                    value = json.dumps(value, ensure_ascii=False, indent=2, default=str)
                result.append(f"### {key}\n```\n{value}\n```\n")

        return "\n".join(result)

    @staticmethod
    def format_for_api(context: Dict[str, Any]) -> Dict[str, Any]:
        """Return an API-friendly (JSON-serializable) copy of a context mapping.

        Values of type str, int, float, bool or None are kept as-is. Dicts and
        lists are converted recursively into new containers. Any other value,
        at any depth, is replaced by its str() form. Dict keys are left
        unchanged. The input mapping and its nested containers are never
        modified.

        Args:
            context: Context mapping to convert.

        Returns:
            A new dict whose values, at every depth, are JSON-native types.
        """

        def _convert(value: Any) -> Any:
            # Recurse into containers so nested leaves are serializable too.
            if isinstance(value, dict):
                return {k: _convert(v) for k, v in value.items()}
            if isinstance(value, list):
                return [_convert(v) for v in value]
            if isinstance(value, (str, int, float, bool, type(None))):
                return value
            return str(value)

        return {key: _convert(value) for key, value in context.items()}
|
@@ -0,0 +1,243 @@
|
|
1
|
+
from typing import Dict, Any, Optional, TypeVar, Generic
|
2
|
+
import asyncio
|
3
|
+
import threading
|
4
|
+
|
5
|
+
|
6
|
+
K = TypeVar('K')
|
7
|
+
V = TypeVar('V')
|
8
|
+
|
9
|
+
class ThreadSafeDict(Generic[K, V]):
    """Dictionary wrapper usable from both synchronous and asynchronous code.

    Locking scheme:

    - ``self._sync_lock`` (``threading.RLock``) guards every access to the
      underlying dict, from sync *and* async methods. Sync and async
      operations are therefore mutually exclusive across threads.
    - ``self._async_lock`` (``asyncio.Lock``) additionally serializes the
      async methods among coroutines of one event loop.
    - Async write methods first take a per-key ``asyncio.Lock``. These locks
      are created and removed under ``self._global_lock``.

    Async methods never ``await`` while holding ``self._sync_lock``. The thread
    lock is only held around plain dict operations.

    Container-returning methods (``keys``/``values``/``items`` and their
    async variants) return list snapshots rather than live views. Callers can
    iterate them while the dict is modified concurrently.

    Note: asyncio locks are not thread-safe and must not be shared across
    event loops. Per-key locks are only discarded by ``adelete``.
    """

    def __init__(self):
        self._dict = {}
        # Guards the underlying dict for both sync and async callers.
        self._sync_lock = threading.RLock()
        # Serializes async methods among coroutines of one event loop.
        self._async_lock = asyncio.Lock()
        # Per-key locks used by async write methods.
        self._key_locks = {}
        # Protects creation/removal of per-key locks.
        self._global_lock = asyncio.Lock()

    async def _get_key_lock(self, key):
        """Return the lock for ``key``, creating it if necessary."""
        async with self._global_lock:
            if key not in self._key_locks:
                self._key_locks[key] = asyncio.Lock()
            return self._key_locks[key]

    # Synchronous methods - protected by threading.RLock
    def __getitem__(self, key: K) -> V:
        """Return the value for ``key`` (``d[key]``); raises KeyError if missing."""
        with self._sync_lock:
            return self._dict[key]

    def __setitem__(self, key: K, value: V) -> None:
        """Set ``key`` to ``value`` (``d[key] = value``)."""
        with self._sync_lock:
            self._dict[key] = value

    def __delitem__(self, key: K) -> None:
        """Delete ``key`` (``del d[key]``); raises KeyError if missing."""
        with self._sync_lock:
            del self._dict[key]

    def __contains__(self, key: K) -> bool:
        """Return whether ``key`` is present (``key in d``)."""
        with self._sync_lock:
            return key in self._dict

    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
        """Return the value for ``key``, or ``default`` if missing."""
        with self._sync_lock:
            return self._dict.get(key, default)

    def __len__(self) -> int:
        """Return the number of entries (``len(d)``)."""
        with self._sync_lock:
            return len(self._dict)

    def keys(self):
        """Return a list snapshot of all keys."""
        with self._sync_lock:
            return list(self._dict.keys())

    def values(self):
        """Return a list snapshot of all values."""
        with self._sync_lock:
            return list(self._dict.values())

    def items(self):
        """Return a list snapshot of all (key, value) pairs."""
        with self._sync_lock:
            return list(self._dict.items())

    def update(self, key: K, updates: Dict[str, Any]) -> bool:
        """Merge ``updates`` into the dict stored under ``key``.

        Returns:
            True if ``key`` exists and holds a dict (and was updated),
            False otherwise.
        """
        with self._sync_lock:
            if key in self._dict and isinstance(self._dict[key], dict):
                self._dict[key].update(updates)
                return True
            return False

    def update_item(self, key: K, item_key: Any, item_value: Any) -> bool:
        """Set ``item_key`` to ``item_value`` in the dict stored under ``key``.

        Returns:
            True if ``key`` exists and holds a dict (and was updated),
            False otherwise.
        """
        with self._sync_lock:
            if key in self._dict and isinstance(self._dict[key], dict):
                self._dict[key][item_key] = item_value
                return True
            return False

    # Asynchronous methods - asyncio.Lock for coroutines, plus the thread lock
    # around every dict access so they also exclude sync callers.
    async def aget(self, key: K, default: Optional[V] = None) -> Optional[V]:
        """Return the value for ``key``, or ``default`` if missing (async)."""
        async with self._async_lock:
            with self._sync_lock:
                return self._dict.get(key, default)

    async def aset(self, key: K, value: V) -> None:
        """Set ``key`` to ``value`` (async)."""
        key_lock = await self._get_key_lock(key)
        async with key_lock:
            async with self._async_lock:
                with self._sync_lock:
                    self._dict[key] = value

    async def adelete(self, key: K) -> bool:
        """Delete ``key`` if present (async).

        Returns:
            True if the key existed and was removed, False otherwise.
        """
        key_lock = await self._get_key_lock(key)
        async with key_lock:
            async with self._async_lock:
                with self._sync_lock:
                    existed = key in self._dict
                    if existed:
                        del self._dict[key]
                if existed:
                    # Also drop the per-key lock. This happens after releasing
                    # the thread lock because acquiring _global_lock may await.
                    async with self._global_lock:
                        self._key_locks.pop(key, None)
                return existed

    async def aupdate(self, key: K, updates: Dict[str, Any]) -> bool:
        """Merge ``updates`` into the dict stored under ``key`` (async).

        Returns:
            True if ``key`` exists and holds a dict (and was updated),
            False otherwise.
        """
        key_lock = await self._get_key_lock(key)
        async with key_lock:
            async with self._async_lock:
                with self._sync_lock:
                    if key in self._dict and isinstance(self._dict[key], dict):
                        self._dict[key].update(updates)
                        return True
                    return False

    async def aupdate_item(self, key: K, item_key: Any, item_value: Any) -> bool:
        """Set ``item_key`` to ``item_value`` in the dict under ``key`` (async).

        Returns:
            True if ``key`` exists and holds a dict (and was updated),
            False otherwise.
        """
        key_lock = await self._get_key_lock(key)
        async with key_lock:
            async with self._async_lock:
                with self._sync_lock:
                    if key in self._dict and isinstance(self._dict[key], dict):
                        self._dict[key][item_key] = item_value
                        return True
                    return False

    async def acontains(self, key: K) -> bool:
        """Return whether ``key`` is present (async)."""
        async with self._async_lock:
            with self._sync_lock:
                return key in self._dict

    async def alen(self) -> int:
        """Return the number of entries (async)."""
        async with self._async_lock:
            with self._sync_lock:
                return len(self._dict)

    async def akeys(self):
        """Return a list snapshot of all keys (async)."""
        async with self._async_lock:
            with self._sync_lock:
                return list(self._dict.keys())

    async def avalues(self):
        """Return a list snapshot of all values (async)."""
        async with self._async_lock:
            with self._sync_lock:
                return list(self._dict.values())

    async def aitems(self):
        """Return a list snapshot of all (key, value) pairs (async)."""
        async with self._async_lock:
            with self._sync_lock:
                return list(self._dict.items())
|
182
|
+
|
183
|
+
if __name__ == "__main__":
    # Smoke demo for the synchronous API.
    def demo_sync_methods():
        print("\n=== 测试同步方法 ===")
        store = ThreadSafeDict()

        # Set, read back, membership test, and default lookup.
        store["key1"] = "value1"
        print("设置并获取:", store["key1"])  # -> value1
        print("键存在性检查:", "key1" in store)  # -> True
        print("获取默认值:", store.get("not_exist", "default"))  # -> default

        # Size of the mapping.
        print("字典长度:", len(store))  # -> 1

        # Snapshot accessors after adding a second key.
        store["key2"] = "value2"
        print("所有键:", store.keys())  # -> ['key1', 'key2']
        print("所有值:", store.values())  # -> ['value1', 'value2']
        print("所有键值对:", store.items())  # -> [('key1', 'value1'), ('key2', 'value2')]

        # In-place updates of a nested dict value.
        store["nested"] = {"a": 1}
        store.update("nested", {"b": 2})
        store.update_item("nested", "c", 3)
        print("嵌套字典:", store["nested"])  # -> {'a': 1, 'b': 2, 'c': 3}

        # Removal.
        del store["key1"]
        print("删除后检查:", "key1" in store)  # -> False

    # Smoke demo for the asynchronous API.
    async def demo_async_methods():
        print("\n=== 测试异步方法 ===")
        store = ThreadSafeDict()

        # Set, read back, and membership test.
        await store.aset("key1", "value1")
        print("异步获取:", await store.aget("key1"))  # -> value1
        print("异步键检查:", await store.acontains("key1"))  # -> True

        # Size and snapshot accessors after adding a second key.
        await store.aset("key2", "value2")
        print("异步长度:", await store.alen())  # -> 2
        print("异步所有键:", await store.akeys())  # -> ['key1', 'key2']
        print("异步所有值:", await store.avalues())  # -> ['value1', 'value2']
        print("异步所有键值对:", await store.aitems())  # -> [('key1', 'value1'), ('key2', 'value2')]

        # In-place updates of a nested dict value.
        await store.aset("nested", {"x": 1})
        await store.aupdate("nested", {"y": 2})
        await store.aupdate_item("nested", "z", 3)
        print("异步嵌套字典:", await store.aget("nested"))  # -> {'x': 1, 'y': 2, 'z': 3}

        # Removal.
        await store.adelete("key1")
        print("异步删除后检查:", await store.acontains("key1"))  # -> False

    # Run both demos.
    demo_sync_methods()
    asyncio.run(demo_async_methods())
|