pycityagent 2.0.0a42__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a42.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a42.dist-info/METADATA +235 -0
- pycityagent-2.0.0a42.dist-info/RECORD +81 -0
- pycityagent-2.0.0a42.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a42.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a42.dist-info/top_level.txt +3 -0
    
        pycityagent/llm/llm.py
    ADDED
    
    | @@ -0,0 +1,377 @@ | |
| 1 | 
            +
            """UrbanLLM: 智能能力类及其定义"""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            from openai import OpenAI, AsyncOpenAI, APIConnectionError, OpenAIError
         | 
| 5 | 
            +
            from zhipuai import ZhipuAI
         | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            logging.getLogger("zhipuai").setLevel(logging.WARNING)
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import asyncio
         | 
| 11 | 
            +
            from http import HTTPStatus
         | 
| 12 | 
            +
            import dashscope
         | 
| 13 | 
            +
            import requests
         | 
| 14 | 
            +
            from dashscope import ImageSynthesis
         | 
| 15 | 
            +
            from PIL import Image
         | 
| 16 | 
            +
            from io import BytesIO
         | 
| 17 | 
            +
            from typing import Any, Optional, Union
         | 
| 18 | 
            +
            from .llmconfig import *
         | 
| 19 | 
            +
            from .utils import *
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import os
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            os.environ["GRPC_VERBOSITY"] = "ERROR"
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            class LLM:
         | 
| 27 | 
            +
                """
         | 
| 28 | 
            +
                大语言模型对象
         | 
| 29 | 
            +
                The LLM Object used by Agent(Soul)
         | 
| 30 | 
            +
                """
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                def __init__(self, config: LLMConfig) -> None:
                    """
                    Build the pool of text-generation clients for the configured provider.

                    Args:
                    - config (LLMConfig): configuration whose ``text`` mapping provides
                      ``request_type`` (one of "openai", "deepseek", "qwen", "zhipuai"),
                      ``api_key`` (a single key or a list of keys; one client is created
                      per key and the clients are used round-robin) and ``model``
                      (read later by ``atext_request``).

                    Raises:
                    - ValueError: if ``request_type`` is unsupported or ``api_key`` is an
                      empty list.
                    """
                    self.config = config
                    request_type = config.text["request_type"]
                    if request_type not in ["openai", "deepseek", "qwen", "zhipuai"]:
                        raise ValueError("Invalid request type for text request")
                    # Usage counters accumulated by atext_request and reported by show_consumption.
                    self.prompt_tokens_used = 0
                    self.completion_tokens_used = 0
                    self.request_number = 0
                    # Optional asyncio.Semaphore installed by set_semaphore.
                    self.semaphore = None
                    # Position of the next client handed out by _get_next_client.
                    self._current_client_index = 0

                    api_keys = self.config.text["api_key"]
                    if not isinstance(api_keys, list):
                        api_keys = [api_keys]
                    if not api_keys:
                        # An empty pool would otherwise only fail later, inside the first
                        # request, with an opaque IndexError from _get_next_client.
                        raise ValueError("At least one API key is required for text request")

                    # OpenAI-compatible providers and their endpoints; None keeps the SDK default.
                    openai_base_urls = {
                        "openai": None,
                        "deepseek": "https://api.deepseek.com/v1",
                        "qwen": "https://dashscope.aliyuncs.com/compatible-mode/v1",
                    }

                    self._aclients = []
                    for api_key in api_keys:
                        if request_type == "zhipuai":
                            # atext_request calls this client's methods without awaiting them.
                            client = ZhipuAI(api_key=api_key, timeout=300)
                        else:
                            base_url = openai_base_urls[request_type]
                            if base_url is None:
                                client = AsyncOpenAI(api_key=api_key, timeout=300)
                            else:
                                client = AsyncOpenAI(
                                    api_key=api_key,
                                    base_url=base_url,
                                    timeout=300,
                                )
                        self._aclients.append(client)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def set_semaphore(self, number_of_coroutine: int):
         | 
| 68 | 
            +
                    self.semaphore = asyncio.Semaphore(number_of_coroutine)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def clear_semaphore(self):
         | 
| 71 | 
            +
                    self.semaphore = None
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def clear_used(self):
         | 
| 74 | 
            +
                    """
         | 
| 75 | 
            +
                    clear the storage of used tokens to start a new log message
         | 
| 76 | 
            +
                    Only support OpenAI category API right now, including OpenAI, Deepseek
         | 
| 77 | 
            +
                    """
         | 
| 78 | 
            +
                    self.prompt_tokens_used = 0
         | 
| 79 | 
            +
                    self.completion_tokens_used = 0
         | 
| 80 | 
            +
                    self.request_number = 0
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def show_consumption(
         | 
| 83 | 
            +
                    self, input_price: Optional[float] = None, output_price: Optional[float] = None
         | 
| 84 | 
            +
                ):
         | 
| 85 | 
            +
                    """
         | 
| 86 | 
            +
                    if you give the input and output price of using model, this function will also calculate the consumption for you
         | 
| 87 | 
            +
                    """
         | 
| 88 | 
            +
                    total_token = self.prompt_tokens_used + self.completion_tokens_used
         | 
| 89 | 
            +
                    if self.completion_tokens_used != 0:
         | 
| 90 | 
            +
                        rate = self.prompt_tokens_used / self.completion_tokens_used
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        rate = "nan"
         | 
| 93 | 
            +
                    if self.request_number != 0:
         | 
| 94 | 
            +
                        TcA = total_token / self.request_number
         | 
| 95 | 
            +
                    else:
         | 
| 96 | 
            +
                        TcA = "nan"
         | 
| 97 | 
            +
                    out = f"""Request Number: {self.request_number}
         | 
| 98 | 
            +
            Token Usage:
         | 
| 99 | 
            +
                - Total tokens: {total_token}
         | 
| 100 | 
            +
                - Prompt tokens: {self.prompt_tokens_used}
         | 
| 101 | 
            +
                - Completion tokens: {self.completion_tokens_used}
         | 
| 102 | 
            +
                - Token per request: {TcA}
         | 
| 103 | 
            +
                - Prompt:Completion ratio: {rate}:1"""
         | 
| 104 | 
            +
                    if input_price != None and output_price != None:
         | 
| 105 | 
            +
                        consumption = (
         | 
| 106 | 
            +
                            self.prompt_tokens_used / 1000000 * input_price
         | 
| 107 | 
            +
                            + self.completion_tokens_used / 1000000 * output_price
         | 
| 108 | 
            +
                        )
         | 
| 109 | 
            +
                        out += f"\n    - Cost Estimation: {consumption}"
         | 
| 110 | 
            +
                    print(out)
         | 
| 111 | 
            +
                    return {
         | 
| 112 | 
            +
                        "total": total_token,
         | 
| 113 | 
            +
                        "prompt": self.prompt_tokens_used,
         | 
| 114 | 
            +
                        "completion": self.completion_tokens_used,
         | 
| 115 | 
            +
                        "ratio": rate,
         | 
| 116 | 
            +
                    }
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def _get_next_client(self):
         | 
| 119 | 
            +
                    """获取下一个要使用的客户端"""
         | 
| 120 | 
            +
                    client = self._aclients[self._current_client_index]
         | 
| 121 | 
            +
                    self._current_client_index = (self._current_client_index + 1) % len(self._aclients)
         | 
| 122 | 
            +
                    return client
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                async def atext_request(
                    self,
                    dialog: Any,
                    temperature: float = 1,
                    max_tokens: Optional[int] = None,
                    top_p: Optional[float] = None,
                    frequency_penalty: Optional[float] = None,
                    presence_penalty: Optional[float] = None,
                    timeout: int = 300,
                    retries=3,
                    tools: Optional[list[dict[str, Any]]] = None,
                    tool_choice: Optional[dict[str, Any]] = None,
                ):
                    """
                    Asynchronous text (chat-completion) request.

                    A client is picked round-robin for every attempt; failed attempts are
                    retried with exponential back-off (1 s, 2 s, 4 s, ...).

                    Args:
                    - dialog: chat messages, forwarded as ``messages``.
                    - temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
                      sampling options forwarded to the provider (the ZhipuAI path only
                      forwards ``temperature`` and ``top_p``).
                    - timeout: request timeout in seconds; for ZhipuAI it also bounds how
                      long the task result is polled for.
                    - retries: maximum number of attempts.
                    - tools / tool_choice: optional function-calling specification.

                    Returns:
                    - The JSON-decoded arguments of the first tool call when ``tools`` is
                      given and the model called a tool, otherwise the message content.
                      ``"wrong config"`` for an unknown request type; ``None`` if
                      ``retries`` is not positive.

                    Raises:
                    - APIConnectionError / OpenAIError: re-raised after the last attempt.
                    - Exception: when a ZhipuAI task does not finish with SUCCESS.
                    """
                    request_type = self.config.text["request_type"]
                    if request_type in ("openai", "deepseek", "qwen"):
                        for attempt in range(retries):
                            try:
                                client = self._get_next_client()
                                response = await self._openai_chat_create(
                                    client,
                                    model=self.config.text["model"],
                                    messages=dialog,
                                    temperature=temperature,
                                    max_tokens=max_tokens,
                                    top_p=top_p,
                                    frequency_penalty=frequency_penalty,
                                    presence_penalty=presence_penalty,
                                    stream=False,
                                    timeout=timeout,
                                    tools=tools,
                                    tool_choice=tool_choice,
                                )
                                return self._record_usage_and_extract(response, tools)
                            except APIConnectionError as e:
                                print("API connection error:", e)
                                if attempt < retries - 1:
                                    await asyncio.sleep(2**attempt)
                                else:
                                    raise e
                            except OpenAIError as e:
                                # openai>=1.0 exposes the HTTP status as `status_code`
                                # (APIStatusError); `http_status` was the pre-1.0 name.
                                status_code = getattr(e, "status_code", None)
                                if status_code is not None:
                                    print(f"HTTP status code: {status_code}, error: {e}")
                                else:
                                    print("An error occurred:", e)
                                if attempt < retries - 1:
                                    await asyncio.sleep(2**attempt)
                                else:
                                    raise e
                    elif request_type == "zhipuai":
                        for attempt in range(retries):
                            try:
                                client = self._get_next_client()
                                # The ZhipuAI client is called synchronously; run its blocking
                                # HTTP calls in a worker thread so the event loop keeps serving
                                # other coroutines meanwhile.
                                response = await asyncio.to_thread(
                                    client.chat.asyncCompletions.create,  # type: ignore
                                    model=self.config.text["model"],
                                    messages=dialog,
                                    temperature=temperature,
                                    top_p=top_p,
                                    timeout=timeout,
                                    tools=tools,
                                    tool_choice=tool_choice,
                                )
                                result_response = await self._zhipuai_poll_result(
                                    client, response.id, timeout
                                )
                                return self._record_usage_and_extract(result_response, tools)
                            # NOTE(review): this is openai's APIConnectionError; the ZhipuAI SDK
                            # presumably raises its own exception types, so this retry may never
                            # trigger here — confirm and catch the SDK's errors instead.
                            except APIConnectionError as e:
                                print("API connection error:", e)
                                if attempt < retries - 1:
                                    await asyncio.sleep(2**attempt)
                                else:
                                    raise e
                    else:
                        print("ERROR: Wrong Config")
                        return "wrong config"

                async def _openai_chat_create(self, client, **kwargs):
                    """Call ``client.chat.completions.create``, inside the semaphore if one is set."""
                    if self.semaphore is None:
                        return await client.chat.completions.create(**kwargs)  # type: ignore
                    async with self.semaphore:
                        return await client.chat.completions.create(**kwargs)  # type: ignore

                async def _zhipuai_poll_result(self, client, task_id, timeout):
                    """
                    Poll a ZhipuAI async-completion task every 0.5 s until it finishes.

                    At most ``int(timeout / 0.5) + 1`` polls are made.

                    Returns:
                    - The retrieved result object once the task reports SUCCESS.

                    Raises:
                    - Exception: if the task reports FAILED or is unfinished after the last poll.
                    """
                    task_status = ""
                    result_response = None
                    cnt_threshold = int(timeout / 0.5)
                    for _ in range(cnt_threshold + 1):
                        result_response = await asyncio.to_thread(
                            client.chat.asyncCompletions.retrieve_completion_result,  # type: ignore
                            id=task_id,
                        )
                        task_status = result_response.task_status
                        if task_status in ("SUCCESS", "FAILED"):
                            break
                        await asyncio.sleep(0.5)
                    if task_status != "SUCCESS":
                        raise Exception(f"Task failed with status: {task_status}")
                    return result_response

                def _record_usage_and_extract(self, response, tools):
                    """
                    Add the response's token usage to the counters and return its payload.

                    Returns the JSON-decoded arguments of the first tool call when ``tools``
                    was supplied and the model called a tool, otherwise the message content.
                    """
                    usage = getattr(response, "usage", None)
                    # Some OpenAI-compatible backends may omit usage; count the request anyway.
                    if usage is not None:
                        self.prompt_tokens_used += usage.prompt_tokens  # type: ignore
                        self.completion_tokens_used += usage.completion_tokens  # type: ignore
                    self.request_number += 1
                    message = response.choices[0].message  # type: ignore
                    if tools and message.tool_calls:
                        return json.loads(message.tool_calls[0].function.arguments)
                    return message.content
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                async def img_understand(
         | 
| 265 | 
            +
                    self, img_path: Union[str, list[str]], prompt: Optional[str] = None
         | 
| 266 | 
            +
                ) -> str:
         | 
| 267 | 
            +
                    """
         | 
| 268 | 
            +
                    图像理解
         | 
| 269 | 
            +
                    Image understanding
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    Args:
         | 
| 272 | 
            +
                    - img_path (Union[str, list[str]]): 目标图像的路径, 既可以是一个路径也可以是包含多张图片路径的list. The path of selected Image
         | 
| 273 | 
            +
                    - prompt (str): 理解提示词 - 例如理解方向. The understanding prompts
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    Returns:
         | 
| 276 | 
            +
                    - (str): the understanding content
         | 
| 277 | 
            +
                    """
         | 
| 278 | 
            +
                    ppt = "如何理解这幅图像?"
         | 
| 279 | 
            +
                    if prompt != None:
         | 
| 280 | 
            +
                        ppt = prompt
         | 
| 281 | 
            +
                    if self.config.image_u["request_type"] == "openai":
         | 
| 282 | 
            +
                        if "api_base" in self.config.image_u.keys():
         | 
| 283 | 
            +
                            api_base = self.config.image_u["api_base"]
         | 
| 284 | 
            +
                        else:
         | 
| 285 | 
            +
                            api_base = None
         | 
| 286 | 
            +
                        client = OpenAI(
         | 
| 287 | 
            +
                            api_key=self.config.text["api_key"],
         | 
| 288 | 
            +
                            base_url=api_base,
         | 
| 289 | 
            +
                        )
         | 
| 290 | 
            +
                        content = []
         | 
| 291 | 
            +
                        content.append({"type": "text", "text": ppt})
         | 
| 292 | 
            +
                        if isinstance(img_path, str):
         | 
| 293 | 
            +
                            base64_image = encode_image(img_path)
         | 
| 294 | 
            +
                            content.append(
         | 
| 295 | 
            +
                                {
         | 
| 296 | 
            +
                                    "type": "image_url",
         | 
| 297 | 
            +
                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
         | 
| 298 | 
            +
                                }
         | 
| 299 | 
            +
                            )
         | 
| 300 | 
            +
                        elif isinstance(img_path, list) and all(
         | 
| 301 | 
            +
                            isinstance(item, str) for item in img_path
         | 
| 302 | 
            +
                        ):
         | 
| 303 | 
            +
                            for item in img_path:
         | 
| 304 | 
            +
                                base64_image = encode_image(item)
         | 
| 305 | 
            +
                                content.append(
         | 
| 306 | 
            +
                                    {
         | 
| 307 | 
            +
                                        "type": "image_url",
         | 
| 308 | 
            +
                                        "image_url": {
         | 
| 309 | 
            +
                                            "url": f"data:image/jpeg;base64,{base64_image}"
         | 
| 310 | 
            +
                                        },
         | 
| 311 | 
            +
                                    }
         | 
| 312 | 
            +
                                )
         | 
| 313 | 
            +
                        response = client.chat.completions.create(
         | 
| 314 | 
            +
                            model=self.config.image_u["model"],
         | 
| 315 | 
            +
                            messages=[{"role": "user", "content": content}],
         | 
| 316 | 
            +
                        )
         | 
| 317 | 
            +
                        return response.choices[0].message.content  # type: ignore
         | 
| 318 | 
            +
                    elif self.config.image_u["request_type"] == "qwen":
         | 
| 319 | 
            +
                        content = []
         | 
| 320 | 
            +
                        if isinstance(img_path, str):
         | 
| 321 | 
            +
                            content.append({"image": "file://" + img_path})
         | 
| 322 | 
            +
                            content.append({"text": ppt})
         | 
| 323 | 
            +
                        elif isinstance(img_path, list) and all(
         | 
| 324 | 
            +
                            isinstance(item, str) for item in img_path
         | 
| 325 | 
            +
                        ):
         | 
| 326 | 
            +
                            for item in img_path:
         | 
| 327 | 
            +
                                content.append({"image": "file://" + item})
         | 
| 328 | 
            +
                            content.append({"text": ppt})
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                        dialog = [{"role": "user", "content": content}]
         | 
| 331 | 
            +
                        response = dashscope.MultiModalConversation.call(
         | 
| 332 | 
            +
                            model=self.config.image_u["model"],
         | 
| 333 | 
            +
                            api_key=self.config.image_u["api_key"],
         | 
| 334 | 
            +
                            messages=dialog,
         | 
| 335 | 
            +
                        )
         | 
| 336 | 
            +
                        if response.status_code == HTTPStatus.OK:  # type: ignore
         | 
| 337 | 
            +
                            return response.output.choices[0]["message"]["content"]  # type: ignore
         | 
| 338 | 
            +
                        else:
         | 
| 339 | 
            +
                            print(response.code)  # type: ignore # The error code.
         | 
| 340 | 
            +
                            return "Error"
         | 
| 341 | 
            +
                    else:
         | 
| 342 | 
            +
                        print(
         | 
| 343 | 
            +
                            "ERROR: wrong image understanding type, only 'openai' and 'openai' is available"
         | 
| 344 | 
            +
                        )
         | 
| 345 | 
            +
                        return "Error"
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                async def img_generate(
                    self,
                    prompt: str,
                    size: str = "512*512",
                    quantity: int = 1,
                    *,
                    download_timeout: float = 60.0,
                ):
                    """
                    Generate images from a text prompt with DashScope ``ImageSynthesis``.

                    Args:
                    - prompt (str): The image generation prompt.
                    - size (str): The image size, default: '512*512'.
                    - quantity (int): The number of images to generate, default: 1.
                    - download_timeout (float): Per-image timeout in seconds for
                      downloading each generated image URL, default: 60.0.

                    Returns:
                    - (list[PIL.Image.Image] | None): The generated images, or ``None``
                      if the synthesis request did not return HTTP 200.

                    Raises:
                    - requests.HTTPError: if downloading a generated image fails.
                    - requests.Timeout: if a download exceeds ``download_timeout``.
                    """
                    import asyncio

                    # ImageSynthesis.call is a blocking network request; run it in a
                    # worker thread so this coroutine does not freeze the event loop.
                    rsp = await asyncio.to_thread(
                        ImageSynthesis.call,
                        model=self.config.image_g["model"],
                        api_key=self.config.image_g["api_key"],
                        prompt=prompt,
                        n=quantity,
                        size=size,
                    )
                    if rsp.status_code != HTTPStatus.OK:
                        print(
                            "Failed, status_code: %s, code: %s, message: %s"
                            % (rsp.status_code, rsp.code, rsp.message)
                        )
                        return None

                    def _download_images():
                        # Fetch every generated image URL (blocking) and decode with PIL.
                        images = []
                        for result in rsp.output.results:
                            resp = requests.get(result.url, timeout=download_timeout)
                            # Fail loudly on HTTP errors instead of handing an error page
                            # to PIL, which would raise a confusing decode error.
                            resp.raise_for_status()
                            images.append(Image.open(BytesIO(resp.content)))
                        return images

                    return await asyncio.to_thread(_download_images)
         | 
| @@ -0,0 +1,13 @@ | |
| 1 | 
            +
            class LLMConfig:
                """
                Configuration holder for the large language model (LLM) clients.

                Args:
                - config (dict): Must contain the keys ``"text_request"``,
                  ``"img_understand_request"`` and ``"img_generate_request"``;
                  a missing key raises ``KeyError``.

                Attributes:
                - config: the original config dict (same object, not a copy).
                - text: ``config["text_request"]``.
                - image_u: ``config["img_understand_request"]``.
                - image_g: ``config["img_generate_request"]``.
                """

                def __init__(self, config: dict) -> None:
                    self.config = config
                    self.text = config["text_request"]
                    # A literal string "None" for api_base is normalized to a real None.
                    # NOTE: self.text is the caller's dict, so this edits it in place.
                    if self.text.get("api_base") == "None":
                        self.text["api_base"] = None
                    self.image_u = config["img_understand_request"]
                    self.image_g = config["img_generate_request"]
         | 
    
        pycityagent/llm/utils.py
    ADDED
    
    
| @@ -0,0 +1,13 @@ | |
| 1 | 
            +
            """Memory."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from .faiss_query import FaissQuery
         | 
| 4 | 
            +
            from .memory import Memory
         | 
| 5 | 
            +
            from .memory_base import MemoryBase, MemoryUnit
         | 
| 6 | 
            +
            from .profile import ProfileMemory, ProfileMemoryUnit
         | 
| 7 | 
            +
            from .self_define import DynamicMemory
         | 
| 8 | 
            +
            from .state import StateMemory
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Public API of the memory package: every class imported above is
            # re-exported so `from ... import *` and API docs see the full set.
            __all__ = [
                "Memory",
                "FaissQuery",
                "MemoryBase",
                "MemoryUnit",
                "ProfileMemory",
                "ProfileMemoryUnit",
                "DynamicMemory",
                "StateMemory",
            ]
         | 
| @@ -0,0 +1,43 @@ | |
| 1 | 
            +
            from pycityproto.city.person.v2.motion_pb2 import Status
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Default value for each profile attribute, keyed by attribute name.
            # Text-valued attributes start as "" and numeric ones ("age", "currency")
            # start as 0.0.
            PROFILE_ATTRIBUTES = {
                "name": "",
                "gender": "",
                "age": 0.0,
                "education": "",
                "skill": "",
                "occupation": "",
                "family_consumption": "",
                "consumption": "",
                "personality": "",
                "income": "",
                "currency": 0.0,
                "residence": "",
                "race": "",
                "religion": "",
                "marital_status": "",
            }
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # Default value for each agent state attribute, keyed by attribute name.
            STATE_ATTRIBUTES = {
                # Base attributes.
                "id": -1,
                "attribute": {},
                "home": {},
                "work": {},
                "schedules": [],
                "vehicle_attribute": {},
                "bus_attribute": {},
                "pedestrian_attribute": {},
                "bike_attribute": {},
                # Motion attributes; status starts as the proto's unspecified value.
                "status": Status.STATUS_UNSPECIFIED,
                "position": {},
                "v": 0.0,
                "direction": 0.0,
                "activity": "",
                "l": 0.0,
            }
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            # Prefix string for self-defined attribute keys
            # (presumably user-declared memory fields — confirm against callers).
            SELF_DEFINE_PREFIX: str = "self_define_"

            # Key string used for timestamp entries; the leading underscore suggests
            # it is appended to another key — TODO confirm against callers.
            TIME_STAMP_KEY: str = "_timestamp"
         |