pycityagent 2.0.0a17__tar.gz → 2.0.0a19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/PKG-INFO +2 -1
  2. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/agent.py +119 -93
  3. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/economy/econ_client.py +2 -2
  4. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/simulator.py +2 -2
  5. pycityagent-2.0.0a19/pycityagent/metrics/__init__.py +5 -0
  6. pycityagent-2.0.0a19/pycityagent/metrics/mlflow_client.py +109 -0
  7. pycityagent-2.0.0a19/pycityagent/metrics/utils/const.py +0 -0
  8. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/simulation/agentgroup.py +56 -21
  9. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/simulation/simulation.py +102 -43
  10. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/workflow/__init__.py +5 -3
  11. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/workflow/block.py +2 -3
  12. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/workflow/tool.py +51 -2
  13. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pyproject.toml +2 -1
  14. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/README.md +0 -0
  15. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/__init__.py +0 -0
  16. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/economy/__init__.py +0 -0
  17. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/__init__.py +0 -0
  18. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/interact/__init__.py +0 -0
  19. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/interact/interact.py +0 -0
  20. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/message/__init__.py +0 -0
  21. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sence/__init__.py +0 -0
  22. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sence/static.py +0 -0
  23. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sidecar/__init__.py +0 -0
  24. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sidecar/sidecarv2.py +0 -0
  25. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/__init__.py +0 -0
  26. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/aoi_service.py +0 -0
  27. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/client.py +0 -0
  28. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/clock_service.py +0 -0
  29. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/economy_services.py +0 -0
  30. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/lane_service.py +0 -0
  31. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/light_service.py +0 -0
  32. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/person_service.py +0 -0
  33. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/road_service.py +0 -0
  34. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/sim_env.py +0 -0
  35. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/sim/social_service.py +0 -0
  36. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/__init__.py +0 -0
  37. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/base64.py +0 -0
  38. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/const.py +0 -0
  39. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/geojson.py +0 -0
  40. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/grpc.py +0 -0
  41. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/map_utils.py +0 -0
  42. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/port.py +0 -0
  43. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/environment/utils/protobuf.py +0 -0
  44. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/llm/__init__.py +0 -0
  45. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/llm/embedding.py +0 -0
  46. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/llm/llm.py +0 -0
  47. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/llm/llmconfig.py +0 -0
  48. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/llm/utils.py +0 -0
  49. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/__init__.py +0 -0
  50. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/const.py +0 -0
  51. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/memory.py +0 -0
  52. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/memory_base.py +0 -0
  53. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/profile.py +0 -0
  54. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/self_define.py +0 -0
  55. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/state.py +0 -0
  56. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/memory/utils.py +0 -0
  57. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/message/__init__.py +0 -0
  58. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/message/messager.py +0 -0
  59. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/simulation/__init__.py +0 -0
  60. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/survey/__init__.py +0 -0
  61. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/survey/manager.py +0 -0
  62. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/survey/models.py +0 -0
  63. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/__init__.py +0 -0
  64. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/avro_schema.py +0 -0
  65. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/decorators.py +0 -0
  66. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/parsers/__init__.py +0 -0
  67. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/parsers/code_block_parser.py +0 -0
  68. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/parsers/json_parser.py +0 -0
  69. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/parsers/parser_base.py +0 -0
  70. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/utils/survey_util.py +0 -0
  71. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/workflow/prompt.py +0 -0
  72. {pycityagent-2.0.0a17 → pycityagent-2.0.0a19}/pycityagent/workflow/trigger.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a17
3
+ Version: 2.0.0a19
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -26,6 +26,7 @@ Requires-Dist: gradio (>=5.7.1,<6.0.0)
26
26
  Requires-Dist: grpcio (==1.67.1)
27
27
  Requires-Dist: langchain-core (>=0.3.28,<0.4.0)
28
28
  Requires-Dist: matplotlib (==3.8.3)
29
+ Requires-Dist: mlflow (>=2.19.0,<3.0.0)
29
30
  Requires-Dist: mosstool (==1.0.24)
30
31
  Requires-Dist: networkx (==3.2.1)
31
32
  Requires-Dist: numpy (>=1.20.0,<2.0.0)
@@ -1,30 +1,28 @@
1
1
  """智能体模板类及其定义"""
2
2
 
3
- from abc import ABC, abstractmethod
4
3
  import asyncio
5
- from uuid import UUID
6
- from copy import deepcopy
7
- from datetime import datetime
8
- from enum import Enum
9
4
  import logging
10
5
  import random
11
6
  import uuid
12
- from typing import Dict, List, Optional,Any
7
+ from abc import ABC, abstractmethod
8
+ from copy import deepcopy
9
+ from datetime import datetime
10
+ from enum import Enum
11
+ from typing import Any, Dict, List, Optional
12
+ from uuid import UUID
13
13
 
14
14
  import fastavro
15
-
16
- from pycityagent.environment.sim.person_service import PersonService
17
15
  from mosstool.util.format_converter import dict2pb
18
16
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
19
- from pycityagent.utils import process_survey_for_llm
20
-
21
- from pycityagent.message.messager import Messager
22
- from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
23
17
 
24
18
  from .economy import EconomyClient
25
19
  from .environment import Simulator
20
+ from .environment.sim.person_service import PersonService
26
21
  from .llm import LLM
27
22
  from .memory import Memory
23
+ from .message.messager import Messager
24
+ from .metrics import MlflowClient
25
+ from .utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
28
26
 
29
27
  logger = logging.getLogger("pycityagent")
30
28
 
@@ -55,6 +53,7 @@ class Agent(ABC):
55
53
  economy_client: Optional[EconomyClient] = None,
56
54
  messager: Optional[Messager] = None,
57
55
  simulator: Optional[Simulator] = None,
56
+ mlflow_client: Optional[MlflowClient] = None,
58
57
  memory: Optional[Memory] = None,
59
58
  avro_file: Optional[Dict[str, str]] = None,
60
59
  ) -> None:
@@ -68,6 +67,7 @@ class Agent(ABC):
68
67
  economy_client (EconomyClient): The `EconomySim` client. Defaults to None.
69
68
  messager (Messager, optional): The messager object. Defaults to None.
70
69
  simulator (Simulator, optional): The simulator object. Defaults to None.
70
+ mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
71
71
  memory (Memory, optional): The memory of the agent. Defaults to None.
72
72
  avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
73
73
  """
@@ -78,6 +78,7 @@ class Agent(ABC):
78
78
  self._economy_client = economy_client
79
79
  self._messager = messager
80
80
  self._simulator = simulator
81
+ self._mlflow_client = mlflow_client
81
82
  self._memory = memory
82
83
  self._exp_id = -1
83
84
  self._agent_id = -1
@@ -112,6 +113,12 @@ class Agent(ABC):
112
113
  """
113
114
  self._simulator = simulator
114
115
 
116
+ def set_mlflow_client(self, mlflow_client: MlflowClient):
117
+ """
118
+ Set the mlflow_client of the agent.
119
+ """
120
+ self._mlflow_client = mlflow_client
121
+
115
122
  def set_economy_client(self, economy_client: EconomyClient):
116
123
  """
117
124
  Set the economy_client of the agent.
@@ -164,6 +171,15 @@ class Agent(ABC):
164
171
  )
165
172
  return self._economy_client
166
173
 
174
+ @property
175
+ def mlflow_client(self):
176
+ """The Agent's MlflowClient"""
177
+ if self._mlflow_client is None:
178
+ raise RuntimeError(
179
+ f"MlflowClient access before assignment, please `set_mlflow_client` first!"
180
+ )
181
+ return self._mlflow_client
182
+
167
183
  @property
168
184
  def memory(self):
169
185
  """The Agent's Memory"""
@@ -218,19 +234,21 @@ class Agent(ABC):
218
234
  response = await self._llm_client.atext_request(dialog) # type:ignore
219
235
 
220
236
  return response # type:ignore
221
-
237
+
222
238
  async def _process_survey(self, survey: dict):
223
239
  survey_response = await self.generate_user_survey_response(survey)
224
240
  if self._avro_file is None:
225
241
  return
226
- response_to_avro = [{
227
- "id": self._uuid,
228
- "day": await self.simulator.get_simulator_day(),
229
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
230
- "survey_id": survey["id"],
231
- "result": survey_response,
232
- "created_at": int(datetime.now().timestamp() * 1000),
233
- }]
242
+ response_to_avro = [
243
+ {
244
+ "id": self._uuid,
245
+ "day": await self.simulator.get_simulator_day(),
246
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
247
+ "survey_id": survey["id"],
248
+ "result": survey_response,
249
+ "created_at": int(datetime.now().timestamp() * 1000),
250
+ }
251
+ ]
234
252
  with open(self._avro_file["survey"], "a+b") as f:
235
253
  fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
236
254
 
@@ -270,28 +288,32 @@ class Agent(ABC):
270
288
  response = await self._llm_client.atext_request(dialog) # type:ignore
271
289
 
272
290
  return response # type:ignore
273
-
291
+
274
292
  async def _process_interview(self, payload: dict):
275
- auros = [{
276
- "id": self._uuid,
277
- "day": await self.simulator.get_simulator_day(),
278
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
279
- "type": 2,
280
- "speaker": "user",
281
- "content": payload["content"],
282
- "created_at": int(datetime.now().timestamp() * 1000),
283
- }]
293
+ auros = [
294
+ {
295
+ "id": self._uuid,
296
+ "day": await self.simulator.get_simulator_day(),
297
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
298
+ "type": 2,
299
+ "speaker": "user",
300
+ "content": payload["content"],
301
+ "created_at": int(datetime.now().timestamp() * 1000),
302
+ }
303
+ ]
284
304
  question = payload["content"]
285
305
  response = await self.generate_user_chat_response(question)
286
- auros.append({
287
- "id": self._uuid,
288
- "day": await self.simulator.get_simulator_day(),
289
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
290
- "type": 2,
291
- "speaker": "",
292
- "content": response,
293
- "created_at": int(datetime.now().timestamp() * 1000),
294
- })
306
+ auros.append(
307
+ {
308
+ "id": self._uuid,
309
+ "day": await self.simulator.get_simulator_day(),
310
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
311
+ "type": 2,
312
+ "speaker": "",
313
+ "content": response,
314
+ "created_at": int(datetime.now().timestamp() * 1000),
315
+ }
316
+ )
295
317
  if self._avro_file is None:
296
318
  return
297
319
  with open(self._avro_file["dialog"], "a+b") as f:
@@ -303,15 +325,17 @@ class Agent(ABC):
303
325
  return resp
304
326
 
305
327
  async def _process_agent_chat(self, payload: dict):
306
- auros = [{
307
- "id": self._uuid,
308
- "day": payload["day"],
309
- "t": payload["t"],
310
- "type": 1,
311
- "speaker": payload["from"],
312
- "content": payload["content"],
313
- "created_at": int(datetime.now().timestamp() * 1000),
314
- }]
328
+ auros = [
329
+ {
330
+ "id": self._uuid,
331
+ "day": payload["day"],
332
+ "t": payload["t"],
333
+ "type": 1,
334
+ "speaker": payload["from"],
335
+ "content": payload["content"],
336
+ "created_at": int(datetime.now().timestamp() * 1000),
337
+ }
338
+ ]
315
339
  asyncio.create_task(self.process_agent_chat_response(payload))
316
340
  if self._avro_file is None:
317
341
  return
@@ -341,18 +365,14 @@ class Agent(ABC):
341
365
  raise NotImplementedError
342
366
 
343
367
  # MQTT send message
344
- async def _send_message(
345
- self, to_agent_uuid: str, payload: dict, sub_topic: str
346
- ):
368
+ async def _send_message(self, to_agent_uuid: str, payload: dict, sub_topic: str):
347
369
  """通过 Messager 发送消息"""
348
370
  if self._messager is None:
349
371
  raise RuntimeError("Messager is not set")
350
372
  topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
351
373
  await self._messager.send_message(topic, payload)
352
374
 
353
- async def send_message_to_agent(
354
- self, to_agent_uuid: str, content: str
355
- ):
375
+ async def send_message_to_agent(self, to_agent_uuid: str, content: str):
356
376
  """通过 Messager 发送消息"""
357
377
  if self._messager is None:
358
378
  raise RuntimeError("Messager is not set")
@@ -364,15 +384,17 @@ class Agent(ABC):
364
384
  "t": await self.simulator.get_simulator_second_from_start_of_day(),
365
385
  }
366
386
  await self._send_message(to_agent_uuid, payload, "agent-chat")
367
- auros = [{
368
- "id": self._uuid,
369
- "day": await self.simulator.get_simulator_day(),
370
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
371
- "type": 1,
372
- "speaker": self._uuid,
373
- "content": content,
374
- "created_at": int(datetime.now().timestamp() * 1000),
375
- }]
387
+ auros = [
388
+ {
389
+ "id": self._uuid,
390
+ "day": await self.simulator.get_simulator_day(),
391
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
392
+ "type": 1,
393
+ "speaker": self._uuid,
394
+ "content": content,
395
+ "created_at": int(datetime.now().timestamp() * 1000),
396
+ }
397
+ ]
376
398
  if self._avro_file is None:
377
399
  return
378
400
  with open(self._avro_file["dialog"], "a+b") as f:
@@ -403,20 +425,22 @@ class CitizenAgent(Agent):
403
425
  name: str,
404
426
  llm_client: Optional[LLM] = None,
405
427
  simulator: Optional[Simulator] = None,
428
+ mlflow_client: Optional[MlflowClient] = None,
406
429
  memory: Optional[Memory] = None,
407
430
  economy_client: Optional[EconomyClient] = None,
408
431
  messager: Optional[Messager] = None,
409
432
  avro_file: Optional[dict] = None,
410
433
  ) -> None:
411
434
  super().__init__(
412
- name,
413
- AgentType.Citizen,
414
- llm_client,
415
- economy_client,
416
- messager,
417
- simulator,
418
- memory,
419
- avro_file,
435
+ name=name,
436
+ type=AgentType.Citizen,
437
+ llm_client=llm_client,
438
+ economy_client=economy_client,
439
+ messager=messager,
440
+ simulator=simulator,
441
+ mlflow_client=mlflow_client,
442
+ memory=memory,
443
+ avro_file=avro_file,
420
444
  )
421
445
 
422
446
  async def bind_to_simulator(self):
@@ -464,9 +488,7 @@ class CitizenAgent(Agent):
464
488
  )
465
489
  person_id = resp["person_id"]
466
490
  await memory.update("id", person_id, protect_llm_read_only_fields=False)
467
- logger.debug(
468
- f"Binding to Person `{person_id}` just added to Simulator"
469
- )
491
+ logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
470
492
  # 防止模拟器还没有到prepare阶段导致get_person出错
471
493
  self._has_bound_to_simulator = True
472
494
  self._agent_id = person_id
@@ -517,24 +539,26 @@ class InstitutionAgent(Agent):
517
539
  name: str,
518
540
  llm_client: Optional[LLM] = None,
519
541
  simulator: Optional[Simulator] = None,
542
+ mlflow_client: Optional[MlflowClient] = None,
520
543
  memory: Optional[Memory] = None,
521
544
  economy_client: Optional[EconomyClient] = None,
522
545
  messager: Optional[Messager] = None,
523
546
  avro_file: Optional[dict] = None,
524
547
  ) -> None:
525
548
  super().__init__(
526
- name,
527
- AgentType.Institution,
528
- llm_client,
529
- economy_client,
530
- messager,
531
- simulator,
532
- memory,
533
- avro_file,
549
+ name=name,
550
+ type=AgentType.Institution,
551
+ llm_client=llm_client,
552
+ economy_client=economy_client,
553
+ mlflow_client=mlflow_client,
554
+ messager=messager,
555
+ simulator=simulator,
556
+ memory=memory,
557
+ avro_file=avro_file,
534
558
  )
535
559
  # 添加响应收集器
536
560
  self._gather_responses: Dict[str, asyncio.Future] = {}
537
-
561
+
538
562
  async def bind_to_simulator(self):
539
563
  await self._bind_to_economy()
540
564
 
@@ -624,22 +648,24 @@ class InstitutionAgent(Agent):
624
648
  """处理收到的消息,识别发送者"""
625
649
  content = payload["content"]
626
650
  sender_id = payload["from"]
627
-
651
+
628
652
  # 将响应存储到对应的Future中
629
653
  response_key = str(sender_id)
630
654
  if response_key in self._gather_responses:
631
- self._gather_responses[response_key].set_result({
632
- "from": sender_id,
633
- "content": content,
634
- })
655
+ self._gather_responses[response_key].set_result(
656
+ {
657
+ "from": sender_id,
658
+ "content": content,
659
+ }
660
+ )
635
661
 
636
662
  async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
637
663
  """从多个智能体收集消息
638
-
664
+
639
665
  Args:
640
666
  agent_uuids: 目标智能体UUID列表
641
667
  target: 要收集的信息类型
642
-
668
+
643
669
  Returns:
644
670
  List[dict]: 收集到的所有响应
645
671
  """
@@ -648,7 +674,7 @@ class InstitutionAgent(Agent):
648
674
  for agent_uuid in agent_uuids:
649
675
  futures[agent_uuid] = asyncio.Future()
650
676
  self._gather_responses[agent_uuid] = futures[agent_uuid]
651
-
677
+
652
678
  # 发送gather请求
653
679
  payload = {
654
680
  "from": self._uuid,
@@ -656,7 +682,7 @@ class InstitutionAgent(Agent):
656
682
  }
657
683
  for agent_uuid in agent_uuids:
658
684
  await self._send_message(agent_uuid, payload, "gather")
659
-
685
+
660
686
  try:
661
687
  # 等待所有响应
662
688
  responses = await asyncio.gather(*futures.values())
@@ -308,11 +308,11 @@ class EconomyClient:
308
308
  # current agent ids and org ids
309
309
  return (list(response.agent_ids), list(response.org_ids))
310
310
 
311
- async def get_org_entity_ids(self, org_type: economyv2.OrgType)->list[int]:
311
+ async def get_org_entity_ids(self, org_type: economyv2.OrgType) -> list[int]:
312
312
  request = org_service.GetOrgEntityIdsRequest(
313
313
  type=org_type,
314
314
  )
315
315
  response: org_service.GetOrgEntityIdsResponse = (
316
316
  await self._aio_stub.GetOrgEntityIds(request)
317
317
  )
318
- return list(response.org_ids)
318
+ return list(response.org_ids)
@@ -5,7 +5,7 @@ import logging
5
5
  import os
6
6
  from collections.abc import Sequence
7
7
  from datetime import datetime, timedelta
8
- from typing import Any, Optional, Tuple, Union, cast
8
+ from typing import Any, Optional, Union, cast
9
9
 
10
10
  from mosstool.type import TripMode
11
11
  from mosstool.util.format_converter import coll2pb
@@ -28,7 +28,7 @@ class Simulator:
28
28
  - Simulator Class
29
29
  """
30
30
 
31
- def __init__(self, config, secure: bool = False) -> None:
31
+ def __init__(self, config:dict, secure: bool = False) -> None:
32
32
  self.config = config
33
33
  """
34
34
  - 模拟器配置
@@ -0,0 +1,5 @@
1
+ from .mlflow_client import MlflowClient
2
+
3
+ __all__ = [
4
+ "MlflowClient",
5
+ ]
@@ -0,0 +1,109 @@
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import uuid
5
+ from collections.abc import Sequence
6
+ from typing import Any, Optional, Union
7
+
8
+ import mlflow
9
+ from mlflow.entities import (Dataset, DatasetInput, Document, Experiment,
10
+ ExperimentTag, FileInfo, InputTag, LifecycleStage,
11
+ LiveSpan, Metric, NoOpSpan, Param, Run, RunData,
12
+ RunInfo, RunInputs, RunStatus, RunTag, SourceType,
13
+ Span, SpanEvent, SpanStatus, SpanStatusCode,
14
+ SpanType, Trace, TraceData, TraceInfo, ViewType)
15
+
16
+ from ..utils.decorators import lock_decorator
17
+
18
+ logger = logging.getLogger("mlflow")
19
+
20
+
21
class MlflowClient:
    """Wrapper around ``mlflow.MlflowClient`` bound to a single MLflow run.

    On construction the client resolves (or creates) an experiment, opens one
    run inside it, and afterwards logs metrics/params/tags against that run.
    """

    def __init__(
        self,
        config: dict,
        mlflow_run_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        experiment_description: Optional[str] = None,
        experiment_tags: Optional[dict[str, Any]] = None,
    ) -> None:
        """Connect to the tracking server and create the run.

        Args:
            config: Must contain ``"mlflow_uri"``; may contain ``"username"``
                and ``"password"``.
            mlflow_run_name: Name of the run; generated from a fresh UUID
                when omitted.
            experiment_name: Name of the experiment; generated from the same
                UUID when omitted.
            experiment_description: Stored under the ``mlflow.note.content``
                experiment tag when given.
            experiment_tags: Extra tags for a newly created experiment.

        Raises:
            KeyError: If ``config`` has no ``"mlflow_uri"`` entry.
            Exception: Whatever ``create_experiment`` raised, when no
                experiment with that name can be looked up either.
        """
        # `os.environ` only accepts `str` values, so assigning a missing
        # credential (None) would raise TypeError. Export only what the config
        # actually provides and leave any pre-existing variables untouched.
        # NOTE: this mutates process-wide environment state.
        for env_name, config_key in (
            ("MLFLOW_TRACKING_USERNAME", "username"),
            ("MLFLOW_TRACKING_PASSWORD", "password"),
        ):
            credential = config.get(config_key)
            if credential is not None:
                os.environ[env_name] = credential

        self._mlflow_uri = config["mlflow_uri"]
        self._client = mlflow.MlflowClient(tracking_uri=self._mlflow_uri)
        self._run_uuid = str(uuid.uuid4())
        self._lock = asyncio.Lock()

        # NOTE(review): the "exp_" / "run_" prefixes look swapped between the
        # run name and the experiment name; kept as-is to preserve the names
        # existing deployments produce -- confirm the intent.
        if mlflow_run_name is None:
            mlflow_run_name = f"exp_{self._run_uuid}"
        if experiment_name is None:
            experiment_name = f"run_{self._run_uuid}"

        # Work on a copy so the caller's dict is never modified.
        tags: dict[str, Any] = dict(experiment_tags) if experiment_tags else {}
        if experiment_description is not None:
            tags["mlflow.note.content"] = experiment_description

        try:
            self._experiment_id = self._client.create_experiment(
                name=experiment_name,
                tags=tags,
            )
        except Exception:
            # Creation failed; fall back to an experiment that already carries
            # this name. If there is none, re-raise the original error with
            # its traceback intact.
            existing = self._client.get_experiment_by_name(experiment_name)
            if existing is None:
                raise
            self._experiment_id = existing.experiment_id

        self._run = self._client.create_run(
            experiment_id=self._experiment_id, run_name=mlflow_run_name
        )
        self._run_id = self._run.info.run_id

    @property
    def client(self) -> mlflow.MlflowClient:
        """The underlying ``mlflow.MlflowClient`` instance."""
        return self._client

    @property
    def run_id(self) -> str:
        """Identifier of the run created by this client."""
        return self._run_id

    # NOTE(review): `lock_decorator` is defined elsewhere; presumably it
    # serializes the decorated coroutines on `self._lock` -- confirm.
    @lock_decorator
    async def log_batch(
        self,
        metrics: Sequence[Metric] = (),
        params: Sequence[Param] = (),
        tags: Sequence[RunTag] = (),
    ):
        """Log several metrics, params and tags against the run in one call."""
        self.client.log_batch(
            run_id=self.run_id, metrics=metrics, params=params, tags=tags
        )

    @lock_decorator
    async def log_metric(
        self,
        key: str,
        value: float,
        step: Optional[int] = None,
        timestamp: Optional[int] = None,
    ):
        """Log a single metric value against the run.

        Args:
            key: Metric name.
            value: Metric value.
            step: Optional step index forwarded to MLflow.
            timestamp: Optional timestamp; coerced to ``int`` when given.
        """
        if timestamp is not None:
            timestamp = int(timestamp)
        self.client.log_metric(
            run_id=self.run_id,
            key=key,
            value=value,
            timestamp=timestamp,
            step=step,
        )