gohumanloop 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,628 @@
1
+ import asyncio
2
+ import logging
3
+ from typing import Dict, Any, Optional
4
+
5
+ import aiohttp
6
+ from pydantic import SecretStr
7
+
8
+ from gohumanloop.core.interface import (
9
+ HumanLoopResult, HumanLoopStatus, HumanLoopType
10
+ )
11
+ from gohumanloop.providers.base import BaseProvider
12
+ from gohumanloop.models.api_model import (
13
+ APIResponse, HumanLoopRequestData, HumanLoopStatusParams, HumanLoopStatusResponse,
14
+ HumanLoopCancelData, HumanLoopCancelConversationData, HumanLoopContinueData
15
+ )
16
+
17
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
18
+
19
class APIProvider(BaseProvider):
    """API-based human-in-the-loop provider that supports integration with third-party service platforms

    This provider communicates with a central service platform via HTTP requests, where the service platform
    handles specific third-party service integrations (such as WeChat, Feishu, DingTalk, etc.).

    New requests are submitted to ``v1/humanloop/request`` (or ``v1/humanloop/continue`` for follow-up
    turns). Each accepted request gets a background task that polls ``v1/humanloop/status`` every
    ``poll_interval`` seconds and mirrors the remote status/response into the locally stored request.
    """

    # Statuses in which a request is still awaiting a human response; any other
    # status is treated as final and ends polling.
    _ACTIVE_STATUSES = (HumanLoopStatus.PENDING, HumanLoopStatus.INPROGRESS)

    # Error reported when neither metadata nor default_platform names a platform.
    _PLATFORM_MISSING_ERROR = (
        "Platform not specified. Please set 'platform' in metadata "
        "or set default_platform during initialization"
    )

    def __init__(
        self,
        name: str,
        api_base_url: str,
        api_key: Optional[SecretStr] = None,
        default_platform: Optional[str] = None,
        request_timeout: int = 30,
        poll_interval: int = 5,
        max_retries: int = 3,
        config: Optional[Dict[str, Any]] = None
    ):
        """Initialize API provider

        Args:
            name: Provider name
            api_base_url: Base URL for API service (a trailing '/' is stripped)
            api_key: API authentication key (optional), sent as a Bearer token
            default_platform: Default platform to use (e.g. "wechat", "feishu")
            request_timeout: Total timeout of a single HTTP attempt, in seconds
            poll_interval: Polling interval in seconds
            max_retries: Maximum number of attempts per API call (at least one is always made)
            config: Additional configuration parameters
        """
        super().__init__(name, config)
        self.api_base_url = api_base_url.rstrip('/')
        self.api_key = api_key
        self.default_platform = default_platform
        self.request_timeout = request_timeout
        self.poll_interval = poll_interval
        self.max_retries = max_retries

        # Running status-polling tasks keyed by (conversation_id, request_id).
        # Entries remove themselves once their task finishes (see _start_polling).
        self._poll_tasks: Dict[tuple, asyncio.Task] = {}

    def __str__(self) -> str:
        """Returns a string description of this instance"""
        base_str = super().__str__()
        api_info = f"- API Provider: API-based human-in-the-loop implementation, connected to {self.api_base_url}\n"
        if self.default_platform:
            api_info += f"    Default Platform: {self.default_platform}\n"
        return f"{api_info}{base_str}"

    # ------------------------------------------------------------------ #
    # HTTP helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_retryable_status(status: int) -> bool:
        """Return True for HTTP statuses worth retrying: server errors, 408 and 429."""
        return status >= 500 or status in (408, 429)

    @staticmethod
    async def _read_json_body(response: aiohttp.ClientResponse) -> Any:
        """Decode a response body as JSON.

        The Content-Type header is not enforced, because servers do not always label
        JSON correctly. For error statuses, an undecodable body (e.g. an HTML page from a
        proxy) yields None, so the caller can still report the HTTP status. For success
        statuses the decode error propagates.
        """
        try:
            return await response.json(content_type=None)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError are both ValueErrors
            if response.status >= 400:
                return None
            raise

    @staticmethod
    def _ensure_mapping(response: Any) -> Dict[str, Any]:
        """Return *response* unchanged if it is a JSON object; raise ValueError otherwise."""
        if not isinstance(response, dict):
            raise ValueError(f"Unexpected API response payload: {response!r}")
        return response

    async def _make_api_request(
        self,
        endpoint: str,
        method: str = "POST",
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Make API request, retrying transient failures

        Args:
            endpoint: API endpoint path (relative to api_base_url)
            method: HTTP method (GET, POST, etc.)
            data: Request body data, sent as JSON (an empty dict sends no body)
            params: URL query parameters
            headers: Extra request headers; they override the defaults

        Returns:
            The decoded JSON body of the first successful (< 400) response. This may be
            None if the body is empty.

        Raises:
            Exception: If the final attempt times out ("API request timeout"), gets an
                HTTP error status (message taken from the body's "error" field when
                present), or any non-retryable 4xx status is received. For other
                failures the last underlying exception is re-raised.
        """
        url = f"{self.api_base_url}/{endpoint.lstrip('/')}"

        request_headers = {"Content-Type": "application/json"}
        if self.api_key:
            request_headers["Authorization"] = f"Bearer {self.api_key.get_secret_value()}"
        if headers:
            request_headers.update(headers)

        timeout = aiohttp.ClientTimeout(total=self.request_timeout)
        # Always make at least one attempt, even if max_retries is configured as 0.
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            is_last_attempt = attempt == attempts - 1
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.request(
                        method=method,
                        url=url,
                        json=data or None,
                        params=params,
                        headers=request_headers,
                        timeout=timeout,
                    ) as response:
                        status = response.status
                        response_data = await self._read_json_body(response)
            except asyncio.TimeoutError as e:
                logger.warning(f"API request timeout (attempt {attempt+1}/{attempts})")
                if is_last_attempt:
                    raise Exception("API request timeout") from e
                await asyncio.sleep(attempt + 1)  # linear backoff
                continue
            except Exception as e:
                logger.error(f"API request error: {str(e)}")
                if is_last_attempt:
                    raise
                await asyncio.sleep(attempt + 1)  # linear backoff
                continue

            if status < 400:
                return response_data

            error_msg = None
            if isinstance(response_data, dict):
                error_msg = response_data.get("error")
            error_msg = error_msg or f"API request failed: {status}"
            logger.error(f"API request failed: {error_msg}")

            # Client errors such as 400/401/404 will not change on retry.
            if is_last_attempt or not self._is_retryable_status(status):
                raise Exception(error_msg)
            await asyncio.sleep(attempt + 1)  # linear backoff

        return None  # unreachable: the last attempt always returns or raises

    # ------------------------------------------------------------------ #
    # Local bookkeeping helpers
    # ------------------------------------------------------------------ #

    def _fail_request(
        self,
        conversation_id: str,
        request_id: str,
        loop_type: HumanLoopType,
        error_msg: str,
    ) -> HumanLoopResult:
        """Mark a stored request as errored and build the matching ERROR result."""
        self._update_request_status_error(conversation_id, request_id, error_msg)
        return HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=loop_type,
            status=HumanLoopStatus.ERROR,
            error=error_msg
        )

    def _start_polling(self, conversation_id: str, request_id: str, platform: str) -> None:
        """Start the background status poller for a request and register it in _poll_tasks."""
        key = (conversation_id, request_id)
        task = asyncio.create_task(
            self._poll_request_status(conversation_id, request_id, platform)
        )
        self._poll_tasks[key] = task

        def _forget(done: asyncio.Task) -> None:
            # Drop the entry once polling ends, so finished tasks do not accumulate.
            if self._poll_tasks.get(key) is done:
                del self._poll_tasks[key]

        task.add_done_callback(_forget)

    def _cancel_poll_task(self, conversation_id: str, request_id: str) -> None:
        """Cancel and unregister the poller for a request, if one is running."""
        task = self._poll_tasks.pop((conversation_id, request_id), None)
        if task is not None and not task.done():
            task.cancel()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def request_humanloop(
        self,
        task_id: str,
        conversation_id: str,
        loop_type: HumanLoopType,
        context: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None
    ) -> HumanLoopResult:
        """Request human-in-the-loop interaction

        Args:
            task_id: Task identifier
            conversation_id: Conversation ID for multi-turn dialogue
            loop_type: Type of loop interaction
            context: Context information provided to humans
            metadata: Additional metadata; "platform" overrides default_platform
            timeout: Request timeout in seconds (no timeout task when falsy)

        Returns:
            HumanLoopResult: Result object containing request ID and initial status. The
            status is PENDING if the platform accepted the request, otherwise ERROR.
        """
        metadata = metadata or {}

        request_id = self._generate_request_id()
        # A missing or empty metadata platform falls back to the default platform.
        platform = metadata.get("platform") or self.default_platform
        self._store_request(
            conversation_id=conversation_id,
            request_id=request_id,
            task_id=task_id,
            loop_type=loop_type,
            context=context,
            metadata={**metadata, "platform": platform},
            timeout=timeout
        )

        if not platform:
            return self._fail_request(
                conversation_id, request_id, loop_type, self._PLATFORM_MISSING_ERROR
            )

        try:
            request_data = HumanLoopRequestData(
                task_id=task_id,
                conversation_id=conversation_id,
                request_id=request_id,
                loop_type=loop_type.value,
                context=context,
                platform=platform,
                metadata=metadata
            ).model_dump()
            response = await self._make_api_request(
                endpoint="v1/humanloop/request",
                method="POST",
                data=request_data
            )
            api_response = APIResponse(**self._ensure_mapping(response))
        except Exception as e:
            logger.error(f"Failed to request human-in-the-loop: {str(e)}")
            return self._fail_request(conversation_id, request_id, loop_type, str(e))

        if not api_response.success:
            error_msg = api_response.error or "API request failed without error message"
            return self._fail_request(conversation_id, request_id, loop_type, error_msg)

        self._start_polling(conversation_id, request_id, platform)
        if timeout:
            self._create_timeout_task(conversation_id, request_id, timeout)

        return HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=loop_type,
            status=HumanLoopStatus.PENDING
        )

    async def check_request_status(
        self,
        conversation_id: str,
        request_id: str
    ) -> HumanLoopResult:
        """Check request status

        Reads only the locally stored request record, which the background poller keeps
        up to date. No API call is made here.

        Args:
            conversation_id: Conversation identifier
            request_id: Request identifier

        Returns:
            HumanLoopResult: Result object containing current status, or an ERROR result
            if the request is unknown.
        """
        request_info = self._get_request(conversation_id, request_id)
        if not request_info:
            return HumanLoopResult(
                conversation_id=conversation_id,
                request_id=request_id,
                loop_type=HumanLoopType.CONVERSATION,
                status=HumanLoopStatus.ERROR,
                error=f"Request '{request_id}' not found in conversation '{conversation_id}'"
            )

        return HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=request_info.get("loop_type", HumanLoopType.CONVERSATION),
            status=request_info.get("status", HumanLoopStatus.PENDING),
            response=request_info.get("response", {}),
            feedback=request_info.get("feedback", {}),
            responded_by=request_info.get("responded_by", None),
            responded_at=request_info.get("responded_at", None),
            error=request_info.get("error", None)
        )

    async def cancel_request(
        self,
        conversation_id: str,
        request_id: str
    ) -> bool:
        """Cancel human-in-the-loop request

        Args:
            conversation_id: Conversation identifier for multi-turn dialogue
            request_id: Request identifier for specific interaction request

        Returns:
            bool: True if the local cancellation and the remote cancel call both
            succeeded, False otherwise.
        """
        # Let the base class update local state first.
        if not await super().cancel_request(conversation_id, request_id):
            return False

        # Stop polling now, whatever the remote call below returns, so a poll result
        # that is still in flight cannot overwrite the locally cancelled request.
        self._cancel_poll_task(conversation_id, request_id)

        request_info = self._get_request(conversation_id, request_id)
        if not request_info:
            return False

        platform = (request_info.get("metadata") or {}).get("platform")
        if not platform:
            logger.error("Cancel request failed: Platform information not found")
            return False

        try:
            cancel_data = HumanLoopCancelData(
                conversation_id=conversation_id,
                request_id=request_id,
                platform=platform
            ).model_dump()
            response = await self._make_api_request(
                endpoint="v1/humanloop/cancel",
                method="POST",
                data=cancel_data
            )
            api_response = APIResponse(**self._ensure_mapping(response))
        except Exception as e:
            logger.error(f"Cancel request failed: {str(e)}")
            return False

        if not api_response.success:
            error_msg = api_response.error or "Cancel request failed without error message"
            logger.error(f"Cancel request failed: {error_msg}")
            return False

        return True

    async def cancel_conversation(
        self,
        conversation_id: str
    ) -> bool:
        """Cancel entire conversation

        Args:
            conversation_id: Conversation identifier

        Returns:
            bool: True if the conversation was cancelled locally and (when it has
            requests) remotely; False otherwise.
        """
        # Let the base class update local state first.
        if not await super().cancel_conversation(conversation_id):
            return False

        # Stop every poller of this conversation up front, so no late poll result can
        # overwrite the cancelled requests, even if the remote call fails.
        for key in [k for k in self._poll_tasks if k[0] == conversation_id]:
            self._cancel_poll_task(*key)

        request_ids = self._get_conversation_requests(conversation_id)
        if not request_ids:
            return True  # No requests to cancel remotely

        # Platform is taken from the first request (assumes all requests share it).
        first_request = self._get_request(conversation_id, request_ids[0])
        if not first_request:
            return False

        platform = (first_request.get("metadata") or {}).get("platform")
        if not platform:
            logger.error("Cancel conversation failed: Platform information not found")
            return False

        try:
            cancel_data = HumanLoopCancelConversationData(
                conversation_id=conversation_id,
                platform=platform
            ).model_dump()
            response = await self._make_api_request(
                endpoint="v1/humanloop/cancel_conversation",
                method="POST",
                data=cancel_data
            )
            api_response = APIResponse(**self._ensure_mapping(response))
        except Exception as e:
            logger.error(f"Cancel conversation failed: {str(e)}")
            return False

        if not api_response.success:
            error_msg = api_response.error or "Cancel conversation failed without error message"
            logger.error(f"Cancel conversation failed: {error_msg}")
            return False

        return True

    async def continue_humanloop(
        self,
        conversation_id: str,
        context: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> HumanLoopResult:
        """Continue human-in-the-loop interaction

        Args:
            conversation_id: Conversation ID for multi-turn dialogue (must already exist)
            context: Context information provided to humans
            metadata: Additional metadata; "platform" overrides default_platform
            timeout: Request timeout in seconds (no timeout task when falsy)

        Returns:
            HumanLoopResult: Result object containing request ID and status. The status
            is PENDING if accepted, ERROR otherwise; for an unknown conversation the
            request_id is "".
        """
        # Check whether the conversation exists
        conversation_info = self._get_conversation(conversation_id)
        if not conversation_info:
            return HumanLoopResult(
                conversation_id=conversation_id,
                request_id="",
                loop_type=HumanLoopType.CONVERSATION,
                status=HumanLoopStatus.ERROR,
                error=f"Conversation '{conversation_id}' not found"
            )

        metadata = metadata or {}
        request_id = self._generate_request_id()
        task_id = conversation_info.get("task_id", "unknown_task")
        # A missing or empty metadata platform falls back to the default platform.
        platform = metadata.get("platform") or self.default_platform

        self._store_request(
            conversation_id=conversation_id,
            request_id=request_id,
            task_id=task_id,
            loop_type=HumanLoopType.CONVERSATION,
            context=context,
            metadata={**metadata, "platform": platform},
            timeout=timeout
        )

        if not platform:
            return self._fail_request(
                conversation_id, request_id, HumanLoopType.CONVERSATION,
                self._PLATFORM_MISSING_ERROR
            )

        try:
            continue_data = HumanLoopContinueData(
                conversation_id=conversation_id,
                request_id=request_id,
                task_id=task_id,
                context=context,
                platform=platform,
                metadata=metadata
            ).model_dump()
            response = await self._make_api_request(
                endpoint="v1/humanloop/continue",
                method="POST",
                data=continue_data
            )
            api_response = APIResponse(**self._ensure_mapping(response))
        except Exception as e:
            logger.error(f"Failed to continue human-in-the-loop: {str(e)}")
            return self._fail_request(
                conversation_id, request_id, HumanLoopType.CONVERSATION, str(e)
            )

        if not api_response.success:
            error_msg = api_response.error or "Continue conversation failed without error message"
            return self._fail_request(
                conversation_id, request_id, HumanLoopType.CONVERSATION, error_msg
            )

        self._start_polling(conversation_id, request_id, platform)
        if timeout:
            self._create_timeout_task(conversation_id, request_id, timeout)

        return HumanLoopResult(
            conversation_id=conversation_id,
            request_id=request_id,
            loop_type=HumanLoopType.CONVERSATION,
            status=HumanLoopStatus.PENDING
        )

    def _is_request_active(self, conversation_id: str, request_id: str) -> bool:
        """Return True if the request exists locally and is still PENDING/INPROGRESS."""
        request_info = self._get_request(conversation_id, request_id)
        return bool(request_info) and request_info.get("status") in self._ACTIVE_STATUSES

    async def _poll_request_status(
        self,
        conversation_id: str,
        request_id: str,
        platform: str
    ) -> None:
        """Poll request status until it becomes final

        Each round queries ``v1/humanloop/status`` and copies the status plus any
        non-None response fields into the stored request. Polling ends when the request
        disappears, reaches a final status (locally or remotely), or the task is
        cancelled. An unsuccessful response (``success`` false) is logged and retried on
        the next interval. An unexpected exception marks the request as ERROR, because
        nothing else would ever move it out of PENDING.

        Args:
            conversation_id: Conversation identifier
            request_id: Request identifier
            platform: Platform identifier
        """
        request_key = (conversation_id, request_id)
        try:
            while True:
                request_info = self._get_request(conversation_id, request_id)
                if not request_info:
                    logger.warning(f"Polling stopped: Request '{request_id}' not found in conversation '{conversation_id}'")
                    return
                if request_info.get("status") not in self._ACTIVE_STATUSES:
                    return

                params = HumanLoopStatusParams(
                    conversation_id=conversation_id,
                    request_id=request_id,
                    platform=platform
                ).model_dump()
                response = await self._make_api_request(
                    endpoint="v1/humanloop/status",
                    method="GET",
                    params=params
                )
                status_response = HumanLoopStatusResponse(**self._ensure_mapping(response))

                if not status_response.success:
                    logger.warning(f"Failed to get status: {status_response.error}")
                    await asyncio.sleep(self.poll_interval)
                    continue

                try:
                    new_status = HumanLoopStatus(status_response.status)
                except ValueError:
                    logger.warning(f"Unknown status value: {status_response.status}, using PENDING")
                    new_status = HumanLoopStatus.PENDING

                # The request may have been cancelled or expired locally while the status
                # call was awaited; never overwrite such a final local status.
                if not self._is_request_active(conversation_id, request_id):
                    return

                if request_key in self._requests:
                    entry = self._requests[request_key]
                    entry["status"] = new_status
                    for field_name in ("response", "feedback", "responded_by", "responded_at", "error"):
                        value = getattr(status_response, field_name, None)
                        if value is not None:
                            entry[field_name] = value

                if new_status not in self._ACTIVE_STATUSES:
                    return

                await asyncio.sleep(self.poll_interval)

        except asyncio.CancelledError:
            logger.info(f"Polling task cancelled: conversation '{conversation_id}', request '{request_id}'")
            raise
        except Exception as e:
            logger.error(f"Polling task error: {str(e)}")
            # With polling stopped the request could never leave PENDING; surface the
            # failure unless the request already reached a final state.
            if self._is_request_active(conversation_id, request_id):
                self._update_request_status_error(
                    conversation_id, request_id, f"Status polling failed: {e}"
                )