agentex-sdk 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agentex/_constants.py CHANGED
@@ -6,9 +6,9 @@ RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
6
6
  OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to"
7
7
 
8
8
  # default timeout is 1 minute
9
- DEFAULT_TIMEOUT = httpx.Timeout(timeout=60, connect=5.0)
10
- DEFAULT_MAX_RETRIES = 2
11
- DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20)
9
+ DEFAULT_TIMEOUT = httpx.Timeout(timeout=300, connect=5.0)
10
+ DEFAULT_MAX_RETRIES = 0
11
+ DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=1000, max_keepalive_connections=1000)
12
12
 
13
13
  INITIAL_RETRY_DELAY = 0.5
14
14
  MAX_RETRY_DELAY = 8.0
agentex/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
2
 
3
3
  __title__ = "agentex"
4
- __version__ = "0.4.11" # x-release-please-version
4
+ __version__ = "0.4.13" # x-release-please-version
@@ -59,6 +59,7 @@ class ACPModule:
59
59
  start_to_close_timeout: timedelta = timedelta(seconds=5),
60
60
  heartbeat_timeout: timedelta = timedelta(seconds=5),
61
61
  retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62
+ request: dict[str, Any] | None = None,
62
63
  ) -> Task:
63
64
  """
64
65
  Create a new task.
@@ -71,6 +72,7 @@ class ACPModule:
71
72
  start_to_close_timeout: The start to close timeout for the task.
72
73
  heartbeat_timeout: The heartbeat timeout for the task.
73
74
  retry_policy: The retry policy for the task.
75
+ request: Additional request context including headers to forward to the agent.
74
76
 
75
77
  Returns:
76
78
  The task entry.
@@ -85,6 +87,7 @@ class ACPModule:
85
87
  params=params,
86
88
  trace_id=trace_id,
87
89
  parent_span_id=parent_span_id,
90
+ request=request,
88
91
  ),
89
92
  response_type=Task,
90
93
  start_to_close_timeout=start_to_close_timeout,
@@ -99,6 +102,7 @@ class ACPModule:
99
102
  params=params,
100
103
  trace_id=trace_id,
101
104
  parent_span_id=parent_span_id,
105
+ request=request,
102
106
  )
103
107
 
104
108
  async def send_event(
@@ -112,15 +116,22 @@ class ACPModule:
112
116
  start_to_close_timeout: timedelta = timedelta(seconds=5),
113
117
  heartbeat_timeout: timedelta = timedelta(seconds=5),
114
118
  retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119
+ request: dict[str, Any] | None = None,
115
120
  ) -> Event:
116
121
  """
117
122
  Send an event to a task.
118
123
 
119
124
  Args:
120
125
  task_id: The ID of the task to send the event to.
121
- data: The data to send to the event.
126
+ content: The content to send to the event.
122
127
  agent_id: The ID of the agent to send the event to.
123
128
  agent_name: The name of the agent to send the event to.
129
+ trace_id: The trace ID for the event.
130
+ parent_span_id: The parent span ID for the event.
131
+ start_to_close_timeout: The start to close timeout for the event.
132
+ heartbeat_timeout: The heartbeat timeout for the event.
133
+ retry_policy: The retry policy for the event.
134
+ request: Additional request context including headers to forward to the agent.
124
135
 
125
136
  Returns:
126
137
  The event entry.
@@ -135,6 +146,7 @@ class ACPModule:
135
146
  content=content,
136
147
  trace_id=trace_id,
137
148
  parent_span_id=parent_span_id,
149
+ request=request,
138
150
  ),
139
151
  response_type=None,
140
152
  start_to_close_timeout=start_to_close_timeout,
@@ -149,6 +161,7 @@ class ACPModule:
149
161
  content=content,
150
162
  trace_id=trace_id,
151
163
  parent_span_id=parent_span_id,
164
+ request=request,
152
165
  )
153
166
 
154
167
  async def send_message(
@@ -162,15 +175,22 @@ class ACPModule:
162
175
  start_to_close_timeout: timedelta = timedelta(seconds=5),
163
176
  heartbeat_timeout: timedelta = timedelta(seconds=5),
164
177
  retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178
+ request: dict[str, Any] | None = None,
165
179
  ) -> List[TaskMessage]:
166
180
  """
167
181
  Send a message to a task.
168
182
 
169
183
  Args:
170
- task_id: The ID of the task to send the message to.
171
184
  content: The task message content to send to the task.
185
+ task_id: The ID of the task to send the message to.
172
186
  agent_id: The ID of the agent to send the message to.
173
187
  agent_name: The name of the agent to send the message to.
188
+ trace_id: The trace ID for the message.
189
+ parent_span_id: The parent span ID for the message.
190
+ start_to_close_timeout: The start to close timeout for the message.
191
+ heartbeat_timeout: The heartbeat timeout for the message.
192
+ retry_policy: The retry policy for the message.
193
+ request: Additional request context including headers to forward to the agent.
174
194
 
175
195
  Returns:
176
196
  The message entry.
@@ -185,6 +205,7 @@ class ACPModule:
185
205
  content=content,
186
206
  trace_id=trace_id,
187
207
  parent_span_id=parent_span_id,
208
+ request=request,
188
209
  ),
189
210
  response_type=TaskMessage,
190
211
  start_to_close_timeout=start_to_close_timeout,
@@ -199,32 +220,43 @@ class ACPModule:
199
220
  content=content,
200
221
  trace_id=trace_id,
201
222
  parent_span_id=parent_span_id,
223
+ request=request,
202
224
  )
203
225
 
204
226
  async def cancel_task(
205
227
  self,
206
228
  task_id: str | None = None,
207
229
  task_name: str | None = None,
230
+ agent_id: str | None = None,
231
+ agent_name: str | None = None,
208
232
  trace_id: str | None = None,
209
233
  parent_span_id: str | None = None,
210
234
  start_to_close_timeout: timedelta = timedelta(seconds=5),
211
235
  heartbeat_timeout: timedelta = timedelta(seconds=5),
212
236
  retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
237
+ request: dict[str, Any] | None = None,
213
238
  ) -> Task:
214
239
  """
215
- Cancel a task.
240
+ Cancel a task by sending cancel request to the agent that owns the task.
216
241
 
217
242
  Args:
218
- task_id: The ID of the task to cancel.
219
- task_name: The name of the task to cancel.
243
+ task_id: ID of the task to cancel.
244
+ task_name: Name of the task to cancel.
245
+ agent_id: ID of the agent that owns the task.
246
+ agent_name: Name of the agent that owns the task.
220
247
  trace_id: The trace ID for the task.
221
248
  parent_span_id: The parent span ID for the task.
222
249
  start_to_close_timeout: The start to close timeout for the task.
223
250
  heartbeat_timeout: The heartbeat timeout for the task.
224
251
  retry_policy: The retry policy for the task.
252
+ request: Additional request context including headers to forward to the agent.
225
253
 
226
254
  Returns:
227
255
  The task entry.
256
+
257
+ Raises:
258
+ ValueError: If neither agent_name nor agent_id is provided,
259
+ or if neither task_name nor task_id is provided
228
260
  """
229
261
  if in_temporal_workflow():
230
262
  return await ActivityHelpers.execute_activity(
@@ -232,8 +264,11 @@ class ACPModule:
232
264
  request=TaskCancelParams(
233
265
  task_id=task_id,
234
266
  task_name=task_name,
267
+ agent_id=agent_id,
268
+ agent_name=agent_name,
235
269
  trace_id=trace_id,
236
270
  parent_span_id=parent_span_id,
271
+ request=request,
237
272
  ),
238
273
  response_type=None,
239
274
  start_to_close_timeout=start_to_close_timeout,
@@ -244,6 +279,9 @@ class ACPModule:
244
279
  return await self._acp_service.task_cancel(
245
280
  task_id=task_id,
246
281
  task_name=task_name,
282
+ agent_id=agent_id,
283
+ agent_name=agent_name,
247
284
  trace_id=trace_id,
248
285
  parent_span_id=parent_span_id,
286
+ request=request,
249
287
  )
@@ -88,6 +88,7 @@ class OpenAIModule:
88
88
  mcp_timeout_seconds: int | None = None,
89
89
  input_guardrails: list[InputGuardrail] | None = None,
90
90
  output_guardrails: list[OutputGuardrail] | None = None,
91
+ max_turns: int | None = None,
91
92
  ) -> SerializableRunResult | RunResult:
92
93
  """
93
94
  Run an agent without streaming or TaskMessage creation.
@@ -114,6 +115,7 @@ class OpenAIModule:
114
115
  mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds.
115
116
  input_guardrails: Optional list of input guardrails to run on initial user input.
116
117
  output_guardrails: Optional list of output guardrails to run on final agent output.
118
+ max_turns: Maximum number of turns the agent can take. Uses Runner's default if None.
117
119
 
118
120
  Returns:
119
121
  Union[SerializableRunResult, RunResult]: SerializableRunResult when in Temporal, RunResult otherwise.
@@ -136,6 +138,7 @@ class OpenAIModule:
136
138
  mcp_timeout_seconds=mcp_timeout_seconds,
137
139
  input_guardrails=input_guardrails,
138
140
  output_guardrails=output_guardrails,
141
+ max_turns=max_turns,
139
142
  )
140
143
  return await ActivityHelpers.execute_activity(
141
144
  activity_name=OpenAIActivityName.RUN_AGENT,
@@ -163,6 +166,7 @@ class OpenAIModule:
163
166
  mcp_timeout_seconds=mcp_timeout_seconds,
164
167
  input_guardrails=input_guardrails,
165
168
  output_guardrails=output_guardrails,
169
+ max_turns=max_turns,
166
170
  )
167
171
 
168
172
  async def run_agent_auto_send(
@@ -191,6 +195,7 @@ class OpenAIModule:
191
195
  mcp_timeout_seconds: int | None = None,
192
196
  input_guardrails: list[InputGuardrail] | None = None,
193
197
  output_guardrails: list[OutputGuardrail] | None = None,
198
+ max_turns: int | None = None,
194
199
  ) -> SerializableRunResult | RunResult:
195
200
  """
196
201
  Run an agent with automatic TaskMessage creation.
@@ -216,6 +221,7 @@ class OpenAIModule:
216
221
  mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds.
217
222
  input_guardrails: Optional list of input guardrails to run on initial user input.
218
223
  output_guardrails: Optional list of output guardrails to run on final agent output.
224
+ max_turns: Maximum number of turns the agent can take. Uses Runner's default if None.
219
225
 
220
226
  Returns:
221
227
  Union[SerializableRunResult, RunResult]: SerializableRunResult when in Temporal, RunResult otherwise.
@@ -239,6 +245,7 @@ class OpenAIModule:
239
245
  mcp_timeout_seconds=mcp_timeout_seconds,
240
246
  input_guardrails=input_guardrails,
241
247
  output_guardrails=output_guardrails,
248
+ max_turns=max_turns,
242
249
  )
243
250
  return await ActivityHelpers.execute_activity(
244
251
  activity_name=OpenAIActivityName.RUN_AGENT_AUTO_SEND,
@@ -267,6 +274,7 @@ class OpenAIModule:
267
274
  mcp_timeout_seconds=mcp_timeout_seconds,
268
275
  input_guardrails=input_guardrails,
269
276
  output_guardrails=output_guardrails,
277
+ max_turns=max_turns,
270
278
  )
271
279
 
272
280
  async def run_agent_streamed(
@@ -291,6 +299,7 @@ class OpenAIModule:
291
299
  mcp_timeout_seconds: int | None = None,
292
300
  input_guardrails: list[InputGuardrail] | None = None,
293
301
  output_guardrails: list[OutputGuardrail] | None = None,
302
+ max_turns: int | None = None,
294
303
  ) -> RunResultStreaming:
295
304
  """
296
305
  Run an agent with streaming enabled but no TaskMessage creation.
@@ -320,6 +329,7 @@ class OpenAIModule:
320
329
  mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds.
321
330
  input_guardrails: Optional list of input guardrails to run on initial user input.
322
331
  output_guardrails: Optional list of output guardrails to run on final agent output.
332
+ max_turns: Maximum number of turns the agent can take. Uses Runner's default if None.
323
333
 
324
334
  Returns:
325
335
  RunResultStreaming: The result of the agent run with streaming.
@@ -352,6 +362,7 @@ class OpenAIModule:
352
362
  mcp_timeout_seconds=mcp_timeout_seconds,
353
363
  input_guardrails=input_guardrails,
354
364
  output_guardrails=output_guardrails,
365
+ max_turns=max_turns,
355
366
  )
356
367
 
357
368
  async def run_agent_streamed_auto_send(
@@ -380,6 +391,7 @@ class OpenAIModule:
380
391
  mcp_timeout_seconds: int | None = None,
381
392
  input_guardrails: list[InputGuardrail] | None = None,
382
393
  output_guardrails: list[OutputGuardrail] | None = None,
394
+ max_turns: int | None = None,
383
395
  ) -> SerializableRunResultStreaming | RunResultStreaming:
384
396
  """
385
397
  Run an agent with streaming enabled and automatic TaskMessage creation.
@@ -405,6 +417,7 @@ class OpenAIModule:
405
417
  output_type: Optional output type.
406
418
  tool_use_behavior: Optional tool use behavior.
407
419
  mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds.
420
+ max_turns: Maximum number of turns the agent can take. Uses Runner's default if None.
408
421
 
409
422
  Returns:
410
423
  Union[SerializableRunResultStreaming, RunResultStreaming]: SerializableRunResultStreaming when in Temporal, RunResultStreaming otherwise.
@@ -428,6 +441,7 @@ class OpenAIModule:
428
441
  mcp_timeout_seconds=mcp_timeout_seconds,
429
442
  input_guardrails=input_guardrails,
430
443
  output_guardrails=output_guardrails,
444
+ max_turns=max_turns
431
445
  )
432
446
  return await ActivityHelpers.execute_activity(
433
447
  activity_name=OpenAIActivityName.RUN_AGENT_STREAMED_AUTO_SEND,
@@ -456,4 +470,5 @@ class OpenAIModule:
456
470
  mcp_timeout_seconds=mcp_timeout_seconds,
457
471
  input_guardrails=input_guardrails,
458
472
  output_guardrails=output_guardrails,
473
+ max_turns=max_turns,
459
474
  )
@@ -229,9 +229,12 @@ def merge_deployment_configs(
229
229
  all_env_vars[EnvVarKeys.AUTH_PRINCIPAL_B64.value] = encoded_principal
230
230
  else:
231
231
  raise DeploymentError(f"Auth principal unable to be encoded for agent_env_config: {agent_env_config}")
232
-
232
+
233
+ logger.info(f"Defined agent helm overrides: {agent_env_config.helm_overrides}")
234
+ logger.info(f"Before-merge helm values: {helm_values}")
233
235
  if agent_env_config.helm_overrides:
234
236
  _deep_merge(helm_values, agent_env_config.helm_overrides)
237
+ logger.info(f"After-merge helm values: {helm_values}")
235
238
 
236
239
  # Set final environment variables
237
240
  # Environment variable precedence: manifest -> environments.yaml -> secrets (highest)
@@ -52,7 +52,7 @@ environments:
52
52
  limits:
53
53
  cpu: "1000m"
54
54
  memory: "2Gi"
55
- temporal:
55
+ temporal-worker:
56
56
  enabled: true
57
57
  replicaCount: 2
58
58
  resources:
@@ -9,6 +9,10 @@ from agentex.types.task import Task
9
9
  from agentex.types.task_message import TaskMessage
10
10
  from agentex.types.task_message_content import TaskMessageContent
11
11
  from agentex.types.task_message_content_param import TaskMessageContentParam
12
+ from agentex.types.agent_rpc_params import (
13
+ ParamsCancelTaskRequest as RpcParamsCancelTaskRequest,
14
+ ParamsSendEventRequest as RpcParamsSendEventRequest,
15
+ )
12
16
 
13
17
  logger = make_logger(__name__)
14
18
 
@@ -30,6 +34,7 @@ class ACPService:
30
34
  params: dict[str, Any] | None = None,
31
35
  trace_id: str | None = None,
32
36
  parent_span_id: str | None = None,
37
+ request: dict[str, Any] | None = None,
33
38
  ) -> Task:
34
39
  trace = self._tracer.trace(trace_id=trace_id)
35
40
  async with trace.span(
@@ -43,6 +48,10 @@ class ACPService:
43
48
  },
44
49
  ) as span:
45
50
  heartbeat_if_in_workflow("task create")
51
+
52
+ # Extract headers from request; pass-through to agent
53
+ extra_headers = request.get("headers") if request else None
54
+
46
55
  if agent_name:
47
56
  json_rpc_response = await self._agentex_client.agents.rpc_by_name(
48
57
  agent_name=agent_name,
@@ -51,6 +60,7 @@ class ACPService:
51
60
  "name": name,
52
61
  "params": params,
53
62
  },
63
+ extra_headers=extra_headers,
54
64
  )
55
65
  elif agent_id:
56
66
  json_rpc_response = await self._agentex_client.agents.rpc(
@@ -60,6 +70,7 @@ class ACPService:
60
70
  "name": name,
61
71
  "params": params,
62
72
  },
73
+ extra_headers=extra_headers,
63
74
  )
64
75
  else:
65
76
  raise ValueError("Either agent_name or agent_id must be provided")
@@ -78,6 +89,7 @@ class ACPService:
78
89
  task_name: str | None = None,
79
90
  trace_id: str | None = None,
80
91
  parent_span_id: str | None = None,
92
+ request: dict[str, Any] | None = None,
81
93
  ) -> List[TaskMessage]:
82
94
  trace = self._tracer.trace(trace_id=trace_id)
83
95
  async with trace.span(
@@ -92,6 +104,10 @@ class ACPService:
92
104
  },
93
105
  ) as span:
94
106
  heartbeat_if_in_workflow("message send")
107
+
108
+ # Extract headers from request; pass-through to agent
109
+ extra_headers = request.get("headers") if request else None
110
+
95
111
  if agent_name:
96
112
  json_rpc_response = await self._agentex_client.agents.rpc_by_name(
97
113
  agent_name=agent_name,
@@ -101,6 +117,7 @@ class ACPService:
101
117
  "content": cast(TaskMessageContentParam, content.model_dump()),
102
118
  "stream": False,
103
119
  },
120
+ extra_headers=extra_headers,
104
121
  )
105
122
  elif agent_id:
106
123
  json_rpc_response = await self._agentex_client.agents.rpc(
@@ -111,12 +128,13 @@ class ACPService:
111
128
  "content": cast(TaskMessageContentParam, content.model_dump()),
112
129
  "stream": False,
113
130
  },
131
+ extra_headers=extra_headers,
114
132
  )
115
133
  else:
116
134
  raise ValueError("Either agent_name or agent_id must be provided")
117
135
 
118
136
  task_messages: List[TaskMessage] = []
119
- logger.info(f"json_rpc_response: {json_rpc_response}")
137
+ logger.info("json_rpc_response: %s", json_rpc_response)
120
138
  if isinstance(json_rpc_response.result, list):
121
139
  for message in json_rpc_response.result:
122
140
  task_message = TaskMessage.model_validate(message)
@@ -137,6 +155,7 @@ class ACPService:
137
155
  task_name: str | None = None,
138
156
  trace_id: str | None = None,
139
157
  parent_span_id: str | None = None,
158
+ request: dict[str, Any] | None = None,
140
159
  ) -> Event:
141
160
  trace = self._tracer.trace(trace_id=trace_id)
142
161
  async with trace.span(
@@ -146,27 +165,33 @@ class ACPService:
146
165
  "agent_id": agent_id,
147
166
  "agent_name": agent_name,
148
167
  "task_id": task_id,
168
+ "task_name": task_name,
149
169
  "content": content,
150
170
  },
151
171
  ) as span:
152
172
  heartbeat_if_in_workflow("event send")
173
+
174
+ # Extract headers from request; pass-through to agent
175
+ extra_headers = request.get("headers") if request else None
176
+
177
+ rpc_event_params: RpcParamsSendEventRequest = {
178
+ "task_id": task_id,
179
+ "task_name": task_name,
180
+ "content": cast(TaskMessageContentParam, content.model_dump()),
181
+ }
153
182
  if agent_name:
154
183
  json_rpc_response = await self._agentex_client.agents.rpc_by_name(
155
184
  agent_name=agent_name,
156
185
  method="event/send",
157
- params={
158
- "task_id": task_id,
159
- "content": cast(TaskMessageContentParam, content.model_dump()),
160
- },
186
+ params=rpc_event_params,
187
+ extra_headers=extra_headers,
161
188
  )
162
189
  elif agent_id:
163
190
  json_rpc_response = await self._agentex_client.agents.rpc(
164
191
  agent_id=agent_id,
165
192
  method="event/send",
166
- params={
167
- "task_id": task_id,
168
- "content": cast(TaskMessageContentParam, content.model_dump()),
169
- },
193
+ params=rpc_event_params,
194
+ extra_headers=extra_headers,
170
195
  )
171
196
  else:
172
197
  raise ValueError("Either agent_name or agent_id must be provided")
@@ -180,9 +205,38 @@ class ACPService:
180
205
  self,
181
206
  task_id: str | None = None,
182
207
  task_name: str | None = None,
208
+ agent_id: str | None = None,
209
+ agent_name: str | None = None,
183
210
  trace_id: str | None = None,
184
211
  parent_span_id: str | None = None,
212
+ request: dict[str, Any] | None = None,
185
213
  ) -> Task:
214
+ """
215
+ Cancel a task by sending cancel request to the agent that owns the task.
216
+
217
+ Args:
218
+ task_id: ID of the task to cancel (passed to agent in params)
219
+ task_name: Name of the task to cancel (passed to agent in params)
220
+ agent_id: ID of the agent that owns the task
221
+ agent_name: Name of the agent that owns the task
222
+ trace_id: Trace ID for tracing
223
+ parent_span_id: Parent span ID for tracing
224
+ request: Additional request context including headers to forward to the agent
225
+
226
+ Returns:
227
+ Task entry representing the cancelled task
228
+
229
+ Raises:
230
+ ValueError: If neither agent_name nor agent_id is provided,
231
+ or if neither task_name nor task_id is provided
232
+ """
233
+ # Require agent identification
234
+ if not agent_name and not agent_id:
235
+ raise ValueError("Either agent_name or agent_id must be provided to identify the agent that owns the task")
236
+
237
+ # Require task identification
238
+ if not task_name and not task_id:
239
+ raise ValueError("Either task_name or task_id must be provided to identify the task to cancel")
186
240
  trace = self._tracer.trace(trace_id=trace_id)
187
241
  async with trace.span(
188
242
  parent_id=parent_span_id,
@@ -190,27 +244,38 @@ class ACPService:
190
244
  input={
191
245
  "task_id": task_id,
192
246
  "task_name": task_name,
247
+ "agent_id": agent_id,
248
+ "agent_name": agent_name,
193
249
  },
194
250
  ) as span:
195
251
  heartbeat_if_in_workflow("task cancel")
252
+
253
+ # Extract headers from request; pass-through to agent
254
+ extra_headers = request.get("headers") if request else None
255
+
256
+ # Build params for the agent (task identification)
257
+ params: RpcParamsCancelTaskRequest = {}
258
+ if task_id:
259
+ params["task_id"] = task_id
196
260
  if task_name:
261
+ params["task_name"] = task_name
262
+
263
+ # Send cancel request to the correct agent
264
+ if agent_name:
197
265
  json_rpc_response = await self._agentex_client.agents.rpc_by_name(
198
- agent_name=task_name,
266
+ agent_name=agent_name,
199
267
  method="task/cancel",
200
- params={
201
- "task_name": task_name,
202
- },
268
+ params=params,
269
+ extra_headers=extra_headers,
203
270
  )
204
- elif task_id:
271
+ else: # agent_id is provided (validated above)
272
+ assert agent_id is not None
205
273
  json_rpc_response = await self._agentex_client.agents.rpc(
206
- agent_id=task_id,
274
+ agent_id=agent_id,
207
275
  method="task/cancel",
208
- params={
209
- "task_id": task_id,
210
- },
276
+ params=params,
277
+ extra_headers=extra_headers,
211
278
  )
212
- else:
213
- raise ValueError("Either task_name or task_id must be provided")
214
279
 
215
280
  task_entry = Task.model_validate(json_rpc_response.result)
216
281
  if span: