pycityagent 1.1.9__py3-none-any.whl → 1.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/simulator.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Simulator: 城市模拟器类及其定义"""
2
2
 
3
- from typing import Optional, Union
3
+ from typing import Optional, Union, Tuple
4
4
  from datetime import datetime, timedelta
5
5
  import asyncio
6
6
  from pycitysim import *
@@ -58,7 +58,7 @@ class Simulator:
58
58
 
59
59
  self.map = map.Map(
60
60
  mongo_uri = "mongodb://sim:FiblabSim1001@mgo.db.fiblab.tech:8635/",
61
- mongo_db = "srt",
61
+ mongo_db = config['map_request']['mongo_db'],
62
62
  mongo_coll = config['map_request']['mongo_coll'],
63
63
  cache_dir = config['map_request']['cache_dir'],
64
64
  )
@@ -67,6 +67,12 @@ class Simulator:
67
67
  - Simulator map object
68
68
  """
69
69
 
70
+ self.pois_matrix: dict[str, list[list[list]]] = {}
71
+ """
72
+ pois的基于区块的划分——方便快速粗略地查询poi
73
+ 通过Simulator.set_pois_matrix()初始化
74
+ """
75
+
70
76
  self.routing = RoutingClient(self.config['route_request']['server'])
71
77
  """
72
78
  - 导航服务grpc客户端
@@ -78,6 +84,20 @@ class Simulator:
78
84
  - 模拟城市当前时间
79
85
  - The current time of simulator
80
86
  """
87
+ self.poi_cate = {'10': 'eat',
88
+ '13': 'shopping',
89
+ '18': 'sports',
90
+ '22': 'excursion',
91
+ '16': 'entertainment',
92
+ '20': 'medical tratment',
93
+ '14': 'trivialities',
94
+ '25': 'financial',
95
+ '12': 'government and political services',
96
+ '23': 'cultural institutions',
97
+ '28': 'residence'}
98
+ self.map_x_gap = None
99
+ self.map_y_gap = None
100
+ self.poi_matrix_centers = []
81
101
 
82
102
  # * Agent相关
83
103
  def FindAgentsByArea(self, req: dict, status=None):
@@ -105,7 +125,7 @@ class Simulator:
105
125
  resp.motions = motions
106
126
  return resp
107
127
 
108
- async def GetCitizenAgent(self, name:str, id:int):
128
+ async def GetCitizenAgent(self, name:str, id:int) -> CitizenAgent:
109
129
  """
110
130
  获取agent
111
131
  Get Agent
@@ -119,8 +139,9 @@ class Simulator:
119
139
  """
120
140
  await self.GetTime()
121
141
  resp = await self._client.person_service.GetPerson({"person_id": id})
122
- base = resp['base']
123
- motion = resp['motion']
142
+ print(f"Agent {id}: {resp}")
143
+ base = resp['person']['base']
144
+ motion = resp['person']['motion']
124
145
  agent = CitizenAgent(
125
146
  name,
126
147
  self.config['simulator']['server'],
@@ -132,7 +153,7 @@ class Simulator:
132
153
  agent.set_streetview_config(self.config['streetview_request'])
133
154
  return agent
134
155
 
135
- async def GetFuncAgent(self, id:int, name:str):
156
+ async def GetFuncAgent(self, name:str) -> FuncAgent:
136
157
  """
137
158
  获取一个Func Agent模板
138
159
 
@@ -145,7 +166,6 @@ class Simulator:
145
166
  """
146
167
  agent = FuncAgent(
147
168
  name,
148
- id+10000000,
149
169
  self.config['simulator']['server'],
150
170
  simulator=self
151
171
  )
@@ -160,6 +180,128 @@ class Simulator:
160
180
  """
161
181
  print("Not Implemented Yet")
162
182
  pass
183
+
184
def set_poi_matrix(self, map:dict=None, row_number:int=12, col_number:int=10, radius:int=10000):
    """
    Initialise ``pois_matrix``, a block-based POI index for fast, coarse POI lookup.

    The area's bounding box is split into ``row_number`` x ``col_number``
    cells. For every category prefix in ``self.poi_cate`` and every cell, the
    POIs within ``radius`` metres of the cell centre are queried once through
    ``self.map.query_pois`` and cached in ``self.pois_matrix[prefix][row][col]``.
    Calling this again rebuilds the index from scratch.

    Args:
    - map (dict | object | None): area to partition. ``None`` uses
      ``self.map``; an object exposing a ``header`` dict is used as-is; a
      plain dict is taken to be the header itself and must contain the keys
      ``east``, ``west``, ``north`` and ``south``. POIs are always queried
      from ``self.map``; this argument only decides which area is split.
    - row_number (int): number of rows (splits the south-north extent), >= 1.
    - col_number (int): number of columns (splits the west-east extent), >= 1.
    - radius (int): search radius around each cell centre, in metres.

    Raises:
    - ValueError: if ``row_number`` or ``col_number`` is smaller than 1.
    """
    from types import SimpleNamespace

    if row_number < 1 or col_number < 1:
        raise ValueError(f"row_number and col_number must be at least 1, got {row_number} and {col_number}")
    if map is None:
        self.matrix_map = self.map
    elif isinstance(map, dict):
        # A bare bounds dict has no ``header`` attribute; wrap it so this
        # method and get_pois_from_matrix can keep reading ``matrix_map.header``.
        self.matrix_map = SimpleNamespace(header=map)
    else:
        self.matrix_map = map
    header = self.matrix_map.header
    print(f"Building Poi searching matrix, Row_number: {row_number}, Col_number: {col_number}, Radius: {radius}m")
    self.map_x_gap = (header['east'] - header['west']) / col_number
    self.map_y_gap = (header['north'] - header['south']) / row_number

    # Start from an empty list: appending to the previous centres on a repeated
    # call would yield 2*row_number rows that no longer match the gaps above.
    self.poi_matrix_centers = []
    for i in range(row_number):
        row_centers = []
        for j in range(col_number):
            center_x = header['west'] + self.map_x_gap*j + self.map_x_gap/2
            center_y = header['south'] + self.map_y_gap*i + self.map_y_gap/2
            row_centers.append((center_x, center_y))
        self.poi_matrix_centers.append(row_centers)

    for pre in self.poi_cate:
        print(f"Building matrix for Poi category: {pre}")
        self.pois_matrix[pre] = [
            [self.map.query_pois(center=center, radius=radius, category_prefix=pre) for center in row_centers]
            for row_centers in self.poi_matrix_centers
        ]
    print("Finished")
219
+
220
def get_pois_from_matrix(self, center:Tuple[float, float], prefix:str):
    """
    Quickly fetch the cached POIs for the matrix cell containing ``center``.

    Args:
    - center (Tuple[float, float]): position as (x, y).
    - prefix (str): POI category prefix; must be a key of ``self.poi_cate``.

    Returns:
    - The POIs cached by ``set_poi_matrix`` for that cell, or ``None`` (after
      printing a message) when the matrix is not built yet, the prefix is
      unknown, or ``center`` lies outside the matrix area.
    """
    if self.map_x_gap is None:
        print("Set Poi Matrix first")
        return
    elif prefix not in self.poi_cate:
        print(f"Wrong prefix, only {self.poi_cate.keys()} is usable")
        return
    header = self.matrix_map.header
    x, y = center[0], center[1]
    if not (header['west'] <= x <= header['east'] and header['south'] <= y <= header['north']):
        print("Wrong center")
        return

    # Match the point to a cell. A point exactly on the north/east edge (or
    # float rounding just below it) yields an index equal to the cell count,
    # so clamp it into the last row/column instead of raising IndexError.
    cells = self.pois_matrix[prefix]
    row = int((y - header['south']) / self.map_y_gap) if self.map_y_gap else 0
    row = min(row, len(cells) - 1)
    col = int((x - header['west']) / self.map_x_gap) if self.map_x_gap else 0
    col = min(col, len(cells[row]) - 1)
    return cells[row][col]
243
+
244
def get_cat_from_pois(self, pois:list):
    """
    Name the most common known POI category in ``pois``.

    A POI's category code is the first two characters of its ``'category'``
    field; codes that are not keys of ``self.poi_cate`` are ignored. When two
    codes tie, the one whose first POI appears earlier in ``pois`` wins.
    Returns the category name from ``self.poi_cate``, or ``""`` when no POI
    has a known code.
    """
    tally = {}
    for item in pois:
        code = item['category'][:2]
        if code in self.poi_cate:
            tally[code] = tally.get(code, 0) + 1
    if not tally:
        return ""
    # max() keeps the first maximal key in insertion (first-seen) order.
    winner = max(tally, key=tally.get)
    return self.poi_cate[winner]
261
+
262
def get_poi_matrix_in_rec(self, center:Tuple[float, float], radius:int=2500, rows:int=5, cols:int=5):
    """
    Bucket the POIs inside the square centred on ``center`` into a grid.

    The square spans ``center[0] ± radius`` horizontally and
    ``center[1] ± radius`` vertically. Only POIs of ``self.map.pois`` strictly
    inside it are kept.

    Args:
    - center (Tuple[float, float]): centre point as (x, y).
    - radius (int): half the side length of the square.
    - rows (int): number of rows (splits the y extent, row 0 = smallest y), >= 1.
    - cols (int): number of columns (splits the x extent, col 0 = smallest x), >= 1.

    Returns:
    - matrix: ``rows`` x ``cols`` nested lists of POIs.
    - matrix_type: per cell (row-major), the dominant category name from
      ``self.get_cat_from_pois``.
    - poi_total_number: per cell (row-major), the number of POIs.
    - poi_type_number: per cell (row-major), how many POIs belong to that
      cell's dominant category.

    Raises:
    - ValueError: if ``rows`` or ``cols`` is smaller than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be at least 1, got {rows} and {cols}")
    north = center[1] + radius
    south = center[1] - radius
    west = center[0] - radius
    east = center[0] + radius
    x_gap = (east - west) / cols
    y_gap = (north - south) / rows

    matrix = [[[] for _ in range(cols)] for _ in range(rows)]
    for poi in self.map.pois.values():
        x = poi['position']['x']
        y = poi['position']['y']
        if west < x < east and south < y < north:
            # Rows split the y extent and columns the x extent (the gaps were
            # previously swapped, misplacing POIs whenever rows != cols). The
            # clamp guards against float rounding right at the far edge.
            row_index = min(int((y - south) / y_gap), rows - 1)
            col_index = min(int((x - west) / x_gap), cols - 1)
            matrix[row_index][col_index].append(poi)

    cells = [cell for row in matrix for cell in row]  # row-major flattening
    matrix_type = [self.get_cat_from_pois(cell) for cell in cells]
    poi_total_number = [len(cell) for cell in cells]
    poi_type_number = [
        sum(1 for poi in cell if self.poi_cate.get(poi['category'][:2]) == dominant)
        for cell, dominant in zip(cells, matrix_type)
    ]
    return matrix, matrix_type, poi_total_number, poi_type_number
163
305
 
164
306
  async def GetTime(self, format_time:bool=False, format:Optional[str]="%H:%M:%S") -> Union[int, str]:
165
307
  """
@@ -1,6 +1,8 @@
1
1
  """UrbanLLM: 智能能力类及其定义"""
2
2
 
3
- from openai import OpenAI
3
+ from openai import OpenAI, AsyncOpenAI, APIConnectionError, OpenAIError
4
+ import openai
5
+ import asyncio
4
6
  from http import HTTPStatus
5
7
  import dashscope
6
8
  import requests
@@ -9,10 +11,55 @@ from PIL import Image
9
11
  from io import BytesIO
10
12
  from typing import Union
11
13
  import base64
14
+ import aiohttp
15
+ import re
12
16
 
13
17
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
20
+
21
class VarSet:
    """
    Variable values for a batched prompt.

    ``var_batch`` holds one dict per request in the batch, mapping each
    variable key to its current value. ``var_des`` maps each variable key to
    a human-readable description.
    """

    def __init__(self, variable_batch, variable_description, K):
        """
        Args:
        - variable_batch (list[dict]): one dict per request; its keys are the
          variables that ``set_variables`` fills in.
        - variable_description (dict): variable key -> description.
        - K (int): batch size (number of requests).
        """
        self.var_batch = variable_batch
        self.var_des = variable_description
        # Expected number of variables per request, taken from the descriptions.
        self.var_size = len(variable_description)
        self.batch_size = K

    def set_variables(self, variables:list[dict]):
        """
        Overwrite the stored values with those in ``variables``.

        ``variables`` must hold exactly ``batch_size`` dicts (one per request,
        in order), the first with ``var_size`` entries, and each must supply a
        value for every key of the matching ``var_batch`` entry. On any
        mismatch a message is printed, ``None`` is returned and ``var_batch``
        is left unchanged.
        """
        if not variables or len(variables) != self.batch_size or len(variables[0]) != self.var_size:
            print(f"Your input variables mis-match the size of current Variable Set: Batch size-{self.batch_size}, Variable size-{self.var_size}")
            return None
        # Validate every request before writing anything: previously a later
        # entry missing a key raised KeyError after earlier requests had
        # already been overwritten, leaving the batch half-updated.
        for index, (current, new) in enumerate(zip(self.var_batch, variables)):
            missing = [var for var in current if var not in new]
            if missing:
                print(f"Request {index} of your input variables is missing: {missing}")
                return None
        for current, new in zip(self.var_batch, variables):
            for var in current:
                current[var] = new[var]

    def get_variables(self):
        """Return the per-request variable dicts."""
        return self.var_batch

    def get_var_des(self):
        """Return one ``batch[x][key] -- description`` line per described variable."""
        return "".join(f"batch[x][{var}] -- {desc}\n" for var, desc in self.var_des.items())
44
+
45
class PromptBatch:
    """
    A batched prompt template together with the variables that fill it.

    ``prompt_raw`` contains ``str.format`` fields such as ``{batch[i][key]}``.
    After ``set_variables`` the rendered text is cached in ``prompt_set``.
    """

    def __init__(self, prompt_batch:str, variable_batch, vasriable_description, K):
        # NOTE: the misspelt parameter name is part of the public signature
        # (keyword callers may use it), so it is kept as-is.
        self.batch_size = K
        self.prompt_raw = prompt_batch
        self.prompt_set = None
        self.var_set = VarSet(variable_batch, vasriable_description, K)

    def set_variables(self, variables:list[dict]):
        """Store ``variables`` in the variable set, then render the template with its values."""
        # NOTE(review): VarSet.set_variables reports a size mismatch by printing
        # and returning None without raising, so the template is still rendered
        # here with whatever values the set already held.
        self.var_set.set_variables(variables)
        current_values = self.get_variables()
        self.prompt_set = self.prompt_raw.format(batch=current_values)

    def get_variables(self):
        """Return the per-request variable dicts held by the variable set."""
        return self.var_set.var_batch

    def get_prompt(self):
        """Return the rendered prompt, or the raw template if nothing has been rendered yet."""
        return self.prompt_raw if self.prompt_set is None else self.prompt_set
16
63
 
17
64
  class LLMConfig:
18
65
  """
@@ -37,6 +84,53 @@ class UrbanLLM:
37
84
  """
38
85
def __init__(self, config: LLMConfig) -> None:
    """
    Args:
    - config (LLMConfig): model configuration. ``config.text['api_key']`` is
      read here to build the shared asynchronous OpenAI client used by
      ``atext_request``.
    """
    self.config = config
    # Running usage statistics, incremented by the OpenAI-compatible request paths.
    self.prompt_tokens_used, self.completion_tokens_used, self.request_number = 0, 0, 0
    # Optional asyncio.Semaphore that caps concurrent async requests (see set_semaphore).
    self.semaphore = None
    api_key = config.text['api_key']
    self._aclient = AsyncOpenAI(api_key=api_key, timeout=300)
92
+
93
def set_semaphore(self, number_of_coroutine:int):
    """
    Limit how many requests may be in flight at once.

    When set, ``atext_request`` acquires this semaphore around its
    OpenAI-compatible API calls.

    Args:
    - number_of_coroutine (int): maximum number of concurrent requests, >= 1.

    Raises:
    - ValueError: if ``number_of_coroutine`` is smaller than 1. A zero-permit
      semaphore would make every guarded request wait forever.
    """
    if number_of_coroutine < 1:
        raise ValueError(f"number_of_coroutine must be at least 1, got {number_of_coroutine}")
    self.semaphore = asyncio.Semaphore(number_of_coroutine)

def clear_semaphore(self):
    """Remove the concurrency limit installed by ``set_semaphore``."""
    self.semaphore = None
98
+
99
def clear_used(self):
    """
    Reset the token and request counters so a new usage report starts from zero.

    At the moment only the OpenAI-compatible request paths (OpenAI, DeepSeek)
    update these counters.
    """
    for counter_name in ("prompt_tokens_used", "completion_tokens_used", "request_number"):
        setattr(self, counter_name, 0)
107
+
108
def show_consumption(self, input_price:float=None, output_price:float=None):
    """
    Print and return a summary of the token usage counted so far.

    Args:
    - input_price (float, optional): price per 1,000,000 prompt tokens.
    - output_price (float, optional): price per 1,000,000 completion tokens.
      The cost is estimated only when both prices are given.

    Returns:
    - dict with keys:
      - ``total``, ``prompt``, ``completion``: token counts.
      - ``ratio``: prompt/completion ratio, or the string ``'nan'`` when no
        completion tokens were used.
      - ``requests``: number of counted requests.
      - ``cost``: cost estimate, or ``None`` without both prices.
    """
    prompt_tokens = self.prompt_tokens_used
    completion_tokens = self.completion_tokens_used
    total_token = prompt_tokens + completion_tokens
    rate = prompt_tokens/completion_tokens if completion_tokens != 0 else 'nan'
    tokens_per_request = total_token/self.request_number if self.request_number != 0 else 'nan'
    lines = [
        f"Request Number: {self.request_number}",
        "Token Usage:",
        f"    - Total tokens: {total_token}",
        f"    - Prompt tokens: {prompt_tokens}",
        f"    - Completion tokens: {completion_tokens}",
        f"    - Token per request: {tokens_per_request}",
        f"    - Prompt:Completion ratio: {rate}:1",
    ]
    consumption = None
    if input_price is not None and output_price is not None:
        consumption = prompt_tokens/1000000*input_price + completion_tokens/1000000*output_price
        lines.append(f"    - Cost Estimation: {consumption}")
    print("\n".join(lines))
    # "requests" and "cost" are additions; the original four keys are unchanged.
    return {"total": total_token, "prompt": prompt_tokens, "completion": completion_tokens, "ratio": rate,
            "requests": self.request_number, "cost": consumption}
133
+
40
134
 
41
135
  def text_request(self, dialog:list[dict], temperature:float=1, max_tokens:int=None, top_p:float=None, frequency_penalty:float=None, presence_penalty:float=None) -> str:
42
136
  """
@@ -72,6 +166,9 @@ class UrbanLLM:
72
166
  frequency_penalty=frequency_penalty,
73
167
  presence_penalty=presence_penalty
74
168
  )
169
+ self.prompt_tokens_used += response.usage.prompt_tokens
170
+ self.completion_tokens_used += response.usage.completion_tokens
171
+ self.request_number += 1
75
172
  return response.choices[0].message.content
76
173
  elif self.config.text['request_type'] == 'qwen':
77
174
  response = dashscope.Generation.call(
@@ -83,12 +180,128 @@ class UrbanLLM:
83
180
  if response.status_code == HTTPStatus.OK:
84
181
  return response.output.choices[0]['message']['content']
85
182
  else:
86
- return "Error: {}".format(response.status_code)
183
+ return "Error: {}, {}".format(response.status_code, response.message)
184
+ elif self.config.text['request_type'] == 'deepseek':
185
+ client = OpenAI(
186
+ api_key=self.config.text['api_key'],
187
+ base_url="https://api.deepseek.com/beta",
188
+ )
189
+ response = client.chat.completions.create(
190
+ model=self.config.text['model'],
191
+ messages=dialog,
192
+ temperature=temperature,
193
+ max_tokens=max_tokens,
194
+ top_p=top_p,
195
+ frequency_penalty=frequency_penalty,
196
+ presence_penalty=presence_penalty,
197
+ stream=False,
198
+ )
199
+ self.prompt_tokens_used += response.usage.prompt_tokens
200
+ self.completion_tokens_used += response.usage.completion_tokens
201
+ self.request_number += 1
202
+ return response.choices[0].message.content
87
203
  else:
88
204
  print("ERROR: Wrong Config")
89
205
  return "wrong config"
206
+
207
async def atext_request(self, dialog:list[dict], temperature:float=1, max_tokens:int=None, top_p:float=None, frequency_penalty:float=None, presence_penalty:float=None, timeout:int=300, retries=3):
    """
    Asynchronous text request (async counterpart of ``text_request``).

    Dispatches on ``self.config.text['request_type']``:
    - ``'openai'``: uses the shared ``self._aclient``.
    - ``'deepseek'``: uses an ``AsyncOpenAI`` client for DeepSeek's beta
      endpoint. The client is created on first use and cached on the instance.
    - ``'qwen'``: posts directly to the DashScope text-generation HTTP API.

    For the OpenAI-compatible types (openai, deepseek):
    - each call runs under ``self.semaphore`` when one is set;
    - the model is taken from ``self.config.text['model']``;
    - failed calls are retried with exponential backoff (1 s, 2 s, 4 s, ...) on
      ``OpenAIError``;
    - token usage is added to the instance counters.

    Args:
    - dialog (list[dict]): chat messages.
    - temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
      sampling options for the OpenAI-compatible APIs (not sent for qwen).
    - timeout (int): request timeout in seconds.
    - retries (int): total number of attempts for the OpenAI-compatible
      types; values below 1 are treated as 1.

    Returns:
    - str: the reply text, or ``"wrong config"`` for an unknown request type.

    Raises:
    - OpenAIError (including APIConnectionError): re-raised after the last
      failed attempt.
    - Exception: when the DashScope response carries an error ``code``.
    """
    text_cfg = self.config.text
    request_type = text_cfg['request_type']

    async def _create(client):
        # One chat-completions call, throttled by the shared semaphore when set.
        kwargs = dict(
            model=text_cfg['model'],
            messages=dialog,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stream=False,
            timeout=timeout,
        )
        if self.semaphore is not None:
            async with self.semaphore:
                return await client.chat.completions.create(**kwargs)
        return await client.chat.completions.create(**kwargs)

    async def _request_with_retries(client):
        # Retry loop shared by every OpenAI-compatible backend. The semaphore is
        # released before sleeping, so backoff does not hold a permit.
        attempts = max(1, retries)
        for attempt in range(attempts):
            try:
                response = await _create(client)
            except APIConnectionError as e:
                print("API connection error:", e)
                if attempt == attempts - 1:
                    raise
            except OpenAIError as e:
                # openai>=1 exposes the HTTP status as ``status_code`` on APIStatusError.
                status = getattr(e, 'status_code', None)
                if status is not None:
                    print(f"HTTP status code: {status}")
                else:
                    print("An error occurred:", e)
                if attempt == attempts - 1:
                    raise
            else:
                self.prompt_tokens_used += response.usage.prompt_tokens
                self.completion_tokens_used += response.usage.completion_tokens
                self.request_number += 1
                return response.choices[0].message.content
            await asyncio.sleep(2 ** attempt)

    if request_type == 'openai':
        return await _request_with_retries(self._aclient)
    elif request_type == 'qwen':
        api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
        # NOTE(review): DashScope's HTTP API documents "Authorization: Bearer <key>";
        # confirm the bare key is accepted.
        headers = {"Content-Type": "application/json", "Authorization": f"{text_cfg['api_key']}"}
        payload = {
            'model': text_cfg['model'],
            'input': {
                'messages': dialog
            }
        }
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
            async with session.post(api_url, json=payload, headers=headers) as resp:
                # Error check: DashScope error bodies carry a "code" field.
                response_json = await resp.json()
                if 'code' in response_json:
                    raise Exception(f"Error: {response_json['code']}, {response_json['message']}")
                return response_json['output']['text']
    elif request_type == 'deepseek':
        client = getattr(self, '_deepseek_aclient', None)
        if client is None:
            # Reuse one client: building one per call opened a fresh connection
            # pool every time and never closed it.
            # NOTE(review): the cached client keeps the api_key seen on first use.
            client = AsyncOpenAI(
                api_key=text_cfg['api_key'],
                base_url="https://api.deepseek.com/beta",
            )
            self._deepseek_aclient = client
        return await _request_with_retries(client)
    else:
        print("ERROR: Wrong Config")
        return "wrong config"
302
+
90
303
 
91
- def img_understand(self, img_path:Union[str, list[str]], prompt:str=None) -> str:
304
+ async def img_understand(self, img_path:Union[str, list[str]], prompt:str=None) -> str:
92
305
  """
93
306
  图像理解
94
307
  Image understanding
@@ -169,7 +382,7 @@ class UrbanLLM:
169
382
  print("ERROR: wrong image understanding type, only 'openai' and 'openai' is available")
170
383
  return "Error"
171
384
 
172
- def img_generate(self, prompt:str, size:str='512*512', quantity:int = 1):
385
+ async def img_generate(self, prompt:str, size:str='512*512', quantity:int = 1):
173
386
  """
174
387
  图像生成
175
388
  Image generation
@@ -197,4 +410,91 @@ class UrbanLLM:
197
410
  else:
198
411
  print('Failed, status_code: %s, code: %s, message: %s' %
199
412
  (rsp.status_code, rsp.code, rsp.message))
200
- return None
413
+ return None
414
+
415
async def convert_prompt_to_batch(self, prompt_raw, K, type:str="mr"):
    """
    Turn a single-target prompt into a batch prompt that serves K targets in one call.

    ``prompt_raw`` consists of a description section and a prompt section,
    separated by ``<commentblockmarker>###</commentblockmarker>``:
    - the description lists variables as lines ``!<INPUT n>! -- description``;
    - the prompt section uses ``!<INPUT n>!`` placeholders.

    Args:
    - prompt_raw (str): the single-target prompt.
    - K (int): batch size.
    - type (str): batch prompt type.
      - ``'mr'`` (multi-request): the prompt section is repeated once per
        request under a ``### Request i ###`` header.
      - ``'gr'`` (gather-request): the multi-request prompt is additionally
        rewritten by the LLM (via ``atext_request``) into a compact form.

    Returns:
    - PromptBatch: its template uses ``{batch[i][n]}`` fields; the initial
      variable values are ``input_i_n`` placeholders.
    - ``None`` (after printing an error) for an unknown ``type``.

    Raises:
    - ValueError: if ``prompt_raw`` does not contain the separator.
    """
    marker = "<commentblockmarker>###</commentblockmarker>"
    if marker not in prompt_raw:
        raise ValueError(f"prompt_raw must contain the separator {marker!r}")
    if type not in ("mr", "gr"):
        print(f"ERROR: unknown batch prompt type '{type}', only 'mr' and 'gr' are supported")
        return None

    # Split the descriptive text from the prompt to be formatted.
    sections = prompt_raw.split(marker)
    description_section = sections[0].strip()
    prompt_section = sections[1].strip()

    # 1. Variable descriptions: placeholder number -> description.
    variable_info = {
        int(var): desc.strip()
        for var, desc in re.findall(r"!<INPUT (\d+)>! -- (.+)", description_section)
    }

    # 2. Unique placeholder numbers, in order of first appearance.
    placeholder_pattern = re.compile(r"!<INPUT (\d+)>!")
    placeholders = list(dict.fromkeys(placeholder_pattern.findall(prompt_section)))

    # 3. Build the batch prompt. PromptBatch later renders it with str.format,
    # so literal braces in the prompt must be doubled first or they would be
    # parsed as (invalid) format fields.
    escaped_section = prompt_section.replace("{", "{{").replace("}", "}}")
    request_blocks = []
    for i in range(K):
        filled = placeholder_pattern.sub(lambda m: f"{{batch[{i}][{m.group(1)}]}}", escaped_section)
        request_blocks.append(f"### Request {i + 1} ###\n" + filled + "\n")
    prompt_batch = "".join(request_blocks)

    # 4. Initial variable values: one dict per request.
    variable_batch = [{int(ph): f"input_{i}_{ph}" for ph in placeholders} for i in range(K)]

    if type == "mr":
        return PromptBatch(prompt_batch, variable_batch, variable_info, K)

    # type == "gr": ask the LLM to aggregate the repeated requests.
    messages = [{
        "role": "system",
        "content": """
You are given a batch prompt structured in multiple requests and a variable description. Your task is to simplify and aggregate this prompt by removing redundant request headers and combining variables across different requests into a more compact format.

### Instructions:
1. Identify all unique sections in the prompt that appear across different requests. These sections will be aggregated into a single instance.
2. For each variable in the prompt (e.g., `{batch[0][0]}`, `{batch[0][1]}`), group them together across all requests, and present them as a list or in a numbered format.
3. Remove all redundant request headers (e.g., `### Request 1 ###`) and replace them with a single unified block where variables are grouped by their positions.
4. Ensure the output is structured clearly for LLM processing, avoiding unnecessary repetition.
5. Ensure the output only contains the transformed prompt.
6. Keep any doubled braces (`{{` and `}}`) exactly as they appear.

Here is an example transformation:
Original prompt:
### Request 1 ###
{batch[0][0]}
In general, Your Lifestyle is as follows: {batch[0][1]}
Today is {batch[0][2]}, your wake up hour:
### Request 2 ###
{batch[1][0]}
In general, Your Lifestyle is as follows: {batch[1][1]}
Today is {batch[1][2]}, your wake up hour:
Variable description:
batch[x][0] -- Identity Stable Set
batch[x][1] -- Lifestyle
batch[x][2] -- Day

Transformed prompt:
Your Identity Stable Set:
1. {batch[0][0]}
2. {batch[1][0]}
Your Lifestyle:
1. {batch[0][1]}
2. {batch[1][1]}
The Day:
1. {batch[0][2]}
2. {batch[1][2]}
Today, your wake up hour is:
""",
    }]
    var_des = "".join(f"batch[x][{var}] -- {desc}\n" for var, desc in variable_info.items())
    user_prompt = f"""Here is the original batch prompt:
{prompt_batch}
The Variable description:
{var_des}
"""
    messages.append({'role': 'user', 'content': user_prompt})
    # NOTE(review): the rewritten prompt comes from the model, so it may still
    # contain braces that str.format rejects; validate before relying on it.
    prompt_batch = await self.atext_request(messages)
    return PromptBatch(prompt_batch, variable_batch, variable_info, K)