python-flexeval 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. flexeval/__init__.py +11 -0
  2. flexeval/__main__.py +11 -0
  3. flexeval/classes/__init__.py +15 -0
  4. flexeval/classes/base.py +32 -0
  5. flexeval/classes/dataset.py +82 -0
  6. flexeval/classes/eval_runner.py +158 -0
  7. flexeval/classes/eval_set_run.py +32 -0
  8. flexeval/classes/message.py +183 -0
  9. flexeval/classes/metric.py +55 -0
  10. flexeval/classes/thread.py +79 -0
  11. flexeval/classes/tool_call.py +51 -0
  12. flexeval/classes/turn.py +206 -0
  13. flexeval/cli.py +104 -0
  14. flexeval/completions.py +147 -0
  15. flexeval/compute_metrics.py +788 -0
  16. flexeval/config.yaml +23 -0
  17. flexeval/configuration/__init__.py +1 -0
  18. flexeval/configuration/completion_functions.py +231 -0
  19. flexeval/configuration/evals.yaml +864 -0
  20. flexeval/configuration/function_metrics.py +650 -0
  21. flexeval/configuration/rubric_metrics.yaml +194 -0
  22. flexeval/data_loader.py +513 -0
  23. flexeval/db_utils.py +38 -0
  24. flexeval/dependency_graph.py +234 -0
  25. flexeval/eval_schema.json +256 -0
  26. flexeval/function_types.py +173 -0
  27. flexeval/helpers.py +52 -0
  28. flexeval/io/__init__.py +1 -0
  29. flexeval/io/parsers/yaml_parser.py +69 -0
  30. flexeval/log_utils.py +34 -0
  31. flexeval/metrics/__init__.py +8 -0
  32. flexeval/metrics/access.py +28 -0
  33. flexeval/metrics/save.py +39 -0
  34. flexeval/rubric.py +62 -0
  35. flexeval/run_utils.py +65 -0
  36. flexeval/runner.py +132 -0
  37. flexeval/schema/__init__.py +11 -0
  38. flexeval/schema/config_schema.py +46 -0
  39. flexeval/schema/eval_schema.py +163 -0
  40. flexeval/schema/evalrun_schema.py +97 -0
  41. flexeval/schema/rubric_schema.py +40 -0
  42. flexeval/schema/schema_utils.py +26 -0
  43. python_flexeval-0.1.5.dist-info/METADATA +118 -0
  44. python_flexeval-0.1.5.dist-info/RECORD +47 -0
  45. python_flexeval-0.1.5.dist-info/WHEEL +4 -0
  46. python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
  47. python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,650 @@
1
+ """Built-in function metrics that can be used in any configuration.
2
+
3
+ See :attr:`~flexeval.schema.evalrun_schema.EvalRun.add_default_functions`."""
4
+
5
+ import datetime
6
+ import json
7
+ import logging
8
+ import os
9
+ import re
10
+ from typing import Union
11
+
12
+ import openai
13
+ import textstat
14
+
15
+ from flexeval.classes.message import Message
16
+ from flexeval.classes.thread import Thread
17
+ from flexeval.classes.tool_call import ToolCall
18
+ from flexeval.classes.turn import Turn
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
# Metric functions in this module may accept any of these input shapes:
#   * one message, given as a plain string;
#   * a whole conversation (i.e., a Thread), given as a list of dictionaries;
#   * an object of type Thread, Turn, Message, or ToolCall.
# Those objects have the same fields as the correspondingly named
# databases and can reach objects at higher or lower levels of
# granularity. The functions further down show how to work with them.
turn_example = "This is a conversational turn."
conversation_example = [
    dict(role="X1", content="Y1"),
    dict(role="X2", content="Y2"),
    Ellipsis,  # placeholder standing in for any further messages
]
36
+
37
+
38
+ # A function template to process a single message
39
def process_single_message(
    message: str,
) -> Union[int, float, dict[str, Union[int, float]]]:
    """Template showing the signature of a metric computed from one message.

    This template does no work and returns ``None``; copy its signature
    when writing a real metric.

    Args:
        message (str): a single conversational message as a string.
            NOTE: Metrics that take a string as input are valid at the
            Turn and Message levels.

    Returns:
        A real metric should return one of:
          * an integer (e.g., 2),
          * a floating point number (e.g., 2.8),
          * a dictionary of metric/value pairs
            (e.g. {'metric1': value1, 'metric2': value2}).
    """
    return None
56
+
57
+
58
+ # A function template to process an entire conversation
59
def process_conversation(
    conversation: list,
) -> Union[
    int, float, dict[str, Union[int, float]], list[dict[str, Union[int, float]]]
]:
    """Template showing the signature of a metric computed from a conversation.

    This template does no work and returns ``None``; copy its signature
    when writing a real metric.

    Args:
        conversation (list): an entire conversation as a list.
            NOTE: Metrics that take a list as input are valid at the
            Thread and Turn levels.

    Returns:
        A real metric should return one of:
          * an integer, e.g., 2,
          * a floating point number, e.g., 2.8,
          * a dictionary of metric/value pairs,
            e.g. {'metric1': value1, 'metric2': value2},
          * a list of dictionaries, where the key can be either 'role' or
            'metric', e.g.
            [{"role": role1, "value": value1}, {"role": role2, "value": value2}, ...].
    """
    return None
79
+
80
+
81
def identity(object: Union[Thread, Turn, Message, ToolCall], **kwargs) -> dict:
    """Report which kind of object was passed in, as a numeric code.

    Args:
        object (Union[Thread, Turn, Message, ToolCall]): The object to classify.
        **kwargs: Accepted and ignored.

    Returns:
        dict: ``{"object_type": code}`` where the code is 0 for a Thread,
        1 for a Turn, 2 for a Message and 3 for a ToolCall.

    Raises:
        ValueError: If the object is none of the four supported types.
    """
    # Checked in this order; the first class that matches determines the code.
    known_types = (Thread, Turn, Message, ToolCall)
    for type_code, candidate in enumerate(known_types):
        if isinstance(object, candidate):
            return {"object_type": type_code}
    raise ValueError(f"Unknown object type {type(object)}.")
101
+
102
+
103
def constant(object: Union[Thread, Turn, Message, ToolCall], **kwargs) -> int | float:
    """Return a fixed value regardless of the input object.

    Args:
        object (Union[Thread, Turn, Message, ToolCall]): Accepted and ignored.
        response (Union[float | int]): Optional keyword argument. When
            present in ``kwargs`` its value is returned unchanged.

    Returns:
        int | float: ``kwargs["response"]`` if it was supplied, otherwise 0.
    """
    return kwargs.get("response", 0)
117
+
118
+
119
def is_role(object: Union[Turn, Message], role: str) -> dict:
    """Flag whether a Turn or Message has the requested role.

    Args:
        object: the Turn or Message whose ``role`` attribute is inspected
        role: a string with the desired role to check against

    Returns:
        dict: a single-entry dictionary keyed by ``role`` whose value is 1
        when ``object.role`` equals ``role`` and 0 otherwise.
    """
    role_matches = object.role == role
    return {role: int(role_matches)}
129
+
130
+
131
def is_langgraph_type(object: Union[Message], type: str) -> dict:
    """Flag whether a Message has the requested langgraph type.

    Args:
        object: the Message whose ``langgraph_type`` attribute is inspected
        type: a string with the desired type to check against

    Returns:
        dict: a single-entry dictionary keyed by ``type`` whose value is 1
        when ``object.langgraph_type`` equals ``type`` and 0 otherwise.
    """
    type_matches = object.langgraph_type == type
    return {type: int(type_matches)}
141
+
142
+
143
def index_in_thread(object: Union[Turn, Message]) -> int:
    """Return the ``index_in_thread`` attribute of the given Turn or Message."""
    position = object.index_in_thread
    return position
145
+
146
+
147
def value_counts_by_tool_name(turn: list, json_key: str) -> dict:
    """Count the values stored under ``json_key`` in tool output, per tool name.

    Only entries whose ``role`` is ``"tool"`` are inspected. Each content
    part of such an entry whose ``type`` is ``"text"`` is parsed as JSON and
    iterated; every item that contains ``json_key`` adds one to the counter
    named ``"<tool name>_<value>"``.

    Args:
        turn (list): message dictionaries. Every entry needs a ``role`` key;
            tool entries also need ``name`` and ``content``, where
            ``content`` is a list of dictionaries with ``type`` and (for
            text parts) ``text`` holding a JSON document.
        json_key (str): the key to look for in each item of the parsed
            tool text.

    Returns:
        dict: maps ``"<tool name>_<value>"`` to the number of occurrences.
        Empty when nothing matched.
    """
    counter = {}
    for entry in turn:
        if entry["role"] != "tool":
            continue
        for content_dict in entry["content"]:
            if content_dict["type"] != "text":
                continue
            for json_dict in json.loads(content_dict["text"]):
                if json_key not in json_dict:
                    continue
                # Format rather than concatenate so that non-string JSON
                # values (numbers, booleans, null) no longer raise TypeError.
                key = f"{entry['name']}_{json_dict[json_key]}"
                counter[key] = counter.get(key, 0) + 1

    return counter
177
+
178
+
179
def message_matches_regex(message: Message, expression: str) -> dict:
    """Count how often a regular expression matches within a message.

    Args:
        message (Message): the message whose ``content`` is searched.
        expression (str): the regular expression supplied by the user.

    Returns:
        dict: ``{expression: n}`` where ``n`` is the number of
        non-overlapping matches reported by ``findall`` (0 when none).
    """
    hits = re.findall(expression, message.content)
    return {expression: len(hits)}
195
+
196
+
197
def count_of_parts_matching_regex(
    object: Union[Thread, Turn, Message], expression: str
) -> int:
    """Total the regular-expression matches across the messages of an object.

    For a Thread or Turn every message in ``object.messages`` is searched;
    any other object is treated as a single message. Tool calls are not
    inspected.

    Args:
        object (Union[Thread, Turn, Message]): the object to search.
        expression (str): the regular expression supplied by the user.

    Returns:
        ``{expression: total}``, where ``total`` is the sum of the
        per-message match counts from ``message_matches_regex``.
        NOTE(review): the signature is annotated ``-> int`` but the value
        returned is a dict.
    """
    parts = object.messages if isinstance(object, (Thread, Turn)) else [object]

    running_total = 0
    for part in parts:
        running_total += message_matches_regex(part, expression)[expression]

    return {expression: running_total}
218
+
219
+
220
def tool_was_called(object: Union[Thread, Turn, Message]) -> float:
    """Return 1 if ``object.toolcalls`` yields at least one item, else 0."""
    # any() stops at the first tool call, like an early return from a loop.
    has_toolcall = any(True for _ in object.toolcalls)
    return 1 if has_toolcall else 0
225
+
226
+
227
def count_tool_calls_by_name(object: Union[Thread, Turn, Message, ToolCall]) -> dict:
    """Count tool calls per function name.

    NOTE: This function provides an example of how to go from higher levels
    of granularity (e.g., Thread) to lower levels of granularity
    (e.g., ToolCall).

    Args:
        object (Union[Thread, Turn, Message, ToolCall]): For a Thread or
            Turn the tool calls of every message are counted; for a Message
            its own tool calls; anything else is treated as one tool call.

    Returns:
        dict: maps each ``function_name`` to the number of tool calls using it.
    """
    # Gather the ToolCall objects reachable from whatever was passed in.
    if isinstance(object, (Thread, Turn)):
        toolcalls = [
            toolcall for message in object.messages for toolcall in message.toolcalls
        ]
    elif isinstance(object, Message):
        toolcalls = list(object.toolcalls)
    else:  # Must be just a tool call
        toolcalls = [object]

    # Tally by function name.
    counts_by_name = {}
    for toolcall in toolcalls:
        name = toolcall.function_name
        counts_by_name[name] = counts_by_name.get(name, 0) + 1

    return counts_by_name
255
+
256
+
257
def count_numeric_tool_call_params_by_name(toolcall: ToolCall) -> list[dict]:
    """Extract the values of all numeric ToolCall parameter inputs.

    ``toolcall.args`` is parsed as a JSON object. Every argument whose value
    can be converted with ``float()`` is reported; all others are skipped.
    Numeric strings such as ``"4.5"`` therefore count as numeric.

    Args:
        toolcall (ToolCall): The tool call; ``toolcall.args`` must hold a
            JSON object.

    Returns:
        list[dict]: One ``{"name": <argument name>, "value": <float>}``
        entry per convertible argument, in the order the arguments appear.
    """
    results = []
    for arg_name, arg_value in json.loads(toolcall.args).items():
        try:
            numeric_val = float(arg_value)
        except (TypeError, ValueError):
            # ValueError: non-numeric strings. TypeError: null, list and
            # object values, which float() rejects outright.
            continue
        results.append({"name": arg_name, "value": numeric_val})
    return results
276
+
277
+
278
def count_llm_models(thread: Thread) -> dict:
    """Count the messages in a thread per LLM model name.

    Useful for quantifying which LLM generated the results - and agents can
    have more than 1 type. Messages whose ``model_name`` is ``None`` are
    left out.

    Args:
        thread (Thread): the thread whose ``messages`` are inspected.

    Returns:
        dict: maps each model name to the number of messages carrying it.
    """
    tallies = {}
    for message in thread.messages:
        model = message.model_name
        if model is None:
            continue
        tallies[model] = tallies.get(model, 0) + 1
    return tallies
287
+
288
+
289
def count_tool_calls(object: Union[Thread, Turn, Message]) -> dict:
    """Count the total number of tool calls in a Thread, Turn or Message.

    Differs from count_tool_calls_by_name because it does not return the
    names of the tool calls.

    Args:
        object (Union[Thread, Turn, Message]): For a Thread or Turn the tool
            calls of every message are added up; any other object is
            treated as a Message and its own ``toolcalls`` are counted.

    Returns:
        The number of tool calls found.
        NOTE(review): the signature is annotated ``-> dict`` but the value
        returned is an int.
    """
    if isinstance(object, (Thread, Turn)):
        return sum(len(list(message.toolcalls)) for message in object.messages)
    # must be a Message
    return len(list(object.toolcalls))
302
+
303
+
304
+ # def count_tool_calls_by_name(object: Union[Thread, Turn, Message, ToolCall]) -> list:
305
+ # # Extract ToolCall objects based on the type of object being passed in
306
+ # toolcalls = []
307
+ # if isinstance(object, (Thread, Turn)):
308
+ # for message in object.messages:
309
+ # toolcalls += [toolcall for toolcall in message.toolcalls]
310
+ # elif isinstance(object, Message):
311
+ # toolcalls += [toolcall for toolcall in object.toolcalls]
312
+ # else: # Must be just a tool call
313
+ # toolcalls.append(object)
314
+
315
+ # # Count the toolcalls
316
+ # toolcall_counts = {}
317
+ # for toolcall in toolcalls:
318
+ # if toolcall.function_name not in toolcall_counts:
319
+ # toolcall_counts[toolcall.function_name] = 0
320
+ # toolcall_counts[toolcall.function_name] = (
321
+ # toolcall_counts[toolcall.function_name] + 1
322
+ # )
323
+
324
+ # # Convert to a list of name: value dictionaries
325
+ # results = []
326
+ # for toolcall_name, toolcall_count in toolcall_counts.items():
327
+ # results.append({"name": toolcall_name, "value": toolcall_count})
328
+ # return results
329
+
330
+
331
def count_messages(object: Union[Thread, Turn]) -> int:
    """Count the messages attached to a Thread or Turn.

    The result is simply the length of ``object.messages``, so a message is
    counted even if its content is blank (e.g., a blank message associated
    with a tool call).
    NOTE(review): no role filtering happens here; system messages are only
    excluded if ``object.messages`` already leaves them out - confirm.

    Args:
        object (Union[Thread, Turn]): the Thread or Turn to inspect.

    Returns:
        int: Count of messages.
    """
    message_total = len(object.messages)
    return message_total
345
+
346
+
347
def count_turns(object: Thread) -> int:
    """Count the conversational turns attached to a thread.

    Args:
        object (Thread): the thread whose ``turns`` are counted.

    Returns:
        int: Count of turns, i.e. the length of ``object.turns``.
    """
    turn_total = len(object.turns)
    return turn_total
358
+
359
+
360
def count_messages_per_role(
    object: Union[Thread, Turn], use_langgraph_roles=False
) -> list:
    """Count the messages of a Thread or Turn per role.

    A message is counted even if its content is blank (e.g., a blank message
    associated with a tool call).
    NOTE(review): no role is filtered out here; the system prompt is only
    excluded if ``object.messages`` already leaves it out - confirm.

    Args:
        object (Union[Thread, Turn]): the Thread or Turn to inspect.
        use_langgraph_roles: when truthy, group by each message's
            ``langgraph_type`` instead of its ``role``.

    Returns:
        A dictionary with roles as keys and message counts as values.
        NOTE(review): the signature is annotated ``-> list`` but the value
        returned is a dict.
    """
    # Decide once which attribute supplies the grouping key.
    key_attribute = "langgraph_type" if use_langgraph_roles else "role"

    counts = {}
    for message in object.messages:
        role = getattr(message, key_attribute)
        counts[role] = counts.get(role, 0) + 1
    return counts
382
+
383
+
384
def is_last_turn_in_thread(turn: Turn) -> int:
    """Flag whether a turn is the final turn of its thread.

    The turn counts as final when its ``id`` equals the largest ``Turn.id``
    stored for the same thread.
    NOTE(review): this treats the highest id as the temporally last turn,
    which presumably relies on turns being inserted in order - confirm.

    Args:
        turn: turn to evaluate

    Returns:
        int: 1 if this turn has the highest id in its thread, 0 otherwise
    """
    from peewee import fn

    # Query for the largest Turn id among the turns of this turn's thread.
    highest_id_query = Turn.select(fn.max(Turn.id)).where(
        Turn.thread_id == turn.thread.id
    )
    highest_id = highest_id_query.scalar()
    return int(highest_id == turn.id)
401
+
402
+
403
# Character class of the code-point ranges treated as emoji. The trailing "+"
# makes one match cover a whole run of adjacent emoji characters.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001f600-\U0001f64f"  # emoticons
    "\U0001f300-\U0001f5ff"  # symbols & pictographs
    "\U0001f680-\U0001f6ff"  # transport & map symbols
    "\U0001f1e0-\U0001f1ff"  # regional indicators (flags)
    "\U0001f900-\U0001f9ff"  # supplemental symbols & pictographs
    "\U0001f200-\U0001f251"  # enclosed ideographic supplement
    "\U00002600-\U000026ff"  # miscellaneous symbols
    "\U00002702-\U000027b0"  # dingbats
    "\U000024c2"  # circled latin capital letter M
    "]+",
    flags=re.UNICODE,
)


def count_emojis(turn: str) -> int:
    """
    Calculate the number of emoji runs in a given text string.

    Adjacent emoji characters form a single run and are counted once, so
    ``"\\U0001f600\\U0001f600"`` counts as 1 while two emoji separated by
    other text count as 2.

    The character class deliberately lists individual emoji blocks. A single
    range from U+24C2 to U+1F251 would also span CJK ideographs, Hangul and
    box-drawing characters and so count ordinary non-English text as emoji.

    Args:
        turn (str): The input text string to be evaluated.

    Returns:
        int: The number of emoji runs in the input text.
    """
    return len(_EMOJI_PATTERN.findall(turn))
425
+
426
+
427
def string_length(object: Union[Thread, Turn, Message]) -> int:
    """
    Calculate the length of the content returned by ``object.get_content()``.

    Args:
        object (Union[Thread, Turn, Message]): the object to measure.

    Returns:
        int: ``len`` of the content when it is a string. Otherwise the
            content is iterated as dictionaries and the lengths of their
            ``"content"`` values are added together (entries without that
            key contribute 0).
    """
    content = object.get_content()
    if isinstance(content, str):
        return len(content)
    # Not a string: add up the lengths of the individual contents.
    return sum(len(item.get("content", "")) for item in content)
448
+
449
+
450
def flesch_reading_ease(turn: str) -> float:
    """
    Calculate the Flesch Reading Ease score for a given text string.

    The Flesch Reading Ease score is a readability test designed to indicate
    how difficult a passage in English is to understand. Higher scores
    indicate material that is easier to read; lower scores indicate material
    that is more difficult to read.

    The score comes from ``textstat.flesch_reading_ease`` and is also written
    to the module logger at debug level.

    Args:
        turn (str): The input text string to be evaluated.

    Returns:
        float: The Flesch Reading Ease score of the input text.
    """
    if turn.strip() == "":
        # NOTE(review): blank input is detected here but not handled any
        # differently; it is still scored below.
        pass
    score = textstat.flesch_reading_ease(turn)
    logger.debug(f"Text '{turn}' has a Flesch Reading Ease score of {score}.")
    return score
469
+
470
+
471
def flesch_kincaid_grade(turn: str) -> float:
    """
    Calculate the Flesch-Kincaid Grade Level score for a given text string.

    The Flesch-Kincaid Grade Level score is a readability test designed to
    indicate the U.S. school grade level of the text. Higher scores indicate
    material that is more difficult to read and understand, suitable for
    higher grade levels.

    The score comes from ``textstat.flesch_kincaid_grade``.

    Args:
        turn (str): The input text string to be evaluated.

    Returns:
        float: The Flesch-Kincaid Grade Level score of the input text.
    """
    grade_level = textstat.flesch_kincaid_grade(turn)
    return grade_level
486
+
487
+
488
def openai_moderation_api(turn: str, **kwargs) -> dict:
    """
    Send a conversational turn to the OpenAI Moderation API and return its
    category scores.

    The client is created with the key read from the ``OPENAI_API_KEY``
    environment variable and the ``omni-moderation-latest`` model is used.

    Args:
        turn (str): The conversational turn to be analyzed.
        **kwargs (Any): Forwarded unchanged to ``client.moderations.create``.

    Returns:
        Dict[str, float]: The category scores of the first result in the
            moderation response, dumped with ``exclude_unset=True`` and
            ``by_alias=True``.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    client = openai.OpenAI(api_key=api_key)
    moderation = client.moderations.create(
        model="omni-moderation-latest", input=turn, **kwargs
    )
    first_result = moderation.results[0]
    return first_result.category_scores.model_dump(exclude_unset=True, by_alias=True)
507
+
508
+
509
def count_errors(object: Union[Thread, Turn, Message, ToolCall]) -> dict:
    """Count the errors of each type recorded on tool calls.

    Errors are found by parsing ``additional_kwargs`` (a JSON document) of a
    tool call and looking for keys ending in ``"_errors"`` whose value is
    not null.

    Args:
        object (Union[Thread, Turn, Message, ToolCall]): For a ToolCall its
            own ``additional_kwargs`` are inspected. For any other object
            every tool call in ``object.toolcalls`` is inspected.

    Returns:
        dict: For a ToolCall, each error key present maps to 1. Otherwise
        each error key maps to the number of tool calls that carry it, e.g.

            {
                "python_errors": 3,
                "javascript_errors": 1
            }
    """
    if isinstance(object, ToolCall):
        # Parse once instead of once per key.
        extra = json.loads(object.additional_kwargs)
        return {
            key: 1
            for key in extra
            if key.endswith("_errors") and extra[key] is not None
        }

    results = {}
    for toolcall in object.toolcalls:
        # Parse once per tool call instead of once per key.
        extra = json.loads(toolcall.additional_kwargs)
        error_keys = [key for key in extra if key.endswith("_errors")]
        for key in error_keys:
            if extra.get(key, None) is not None:
                results[key] = results.get(key, 0) + 1
    return results
543
+
544
+
545
def count_tokens(object: Union[Thread, Turn, Message]) -> dict:
    """
    Count how many prompt tokens and completion tokens were used.

    These values are recorded at the Message level, so they are summed over
    ``object.messages`` when the input is a Thread or Turn. Values that are
    ``None`` are skipped. An object of any other type yields zero for both.

    Args:
        object (Union[Thread, Turn, Message]): the object to total.

    Returns:
        dict: ``{"prompt_tokens": <sum>, "completion_tokens": <sum>}``
    """
    # Normalise the input to the collection of messages to total.
    if isinstance(object, (Thread, Turn)):
        messages = object.messages
    elif isinstance(object, Message):
        messages = [object]
    else:
        messages = []

    totals = {"prompt_tokens": 0, "completion_tokens": 0}
    for message in messages:
        for field in totals:
            used = getattr(message, field)
            if used is not None:
                totals[field] += used

    return totals
568
+
569
+
570
def _seconds_since_previous_message(thread, timestamp: str) -> float:
    """Return the seconds between ``timestamp`` and the closest earlier message.

    Looks up the message of ``thread`` with the greatest ``langgraph_ts``
    that is still smaller than ``timestamp``. Returns 0.0 when there is no
    such message, because the latency cannot be known in that case.
    """
    previous_message = (
        Message.select()
        .where(Message.thread == thread, Message.langgraph_ts < timestamp)
        .order_by(Message.langgraph_ts.desc())
        .first()
    )
    if previous_message is None:
        return 0.0
    elapsed = datetime.datetime.fromisoformat(
        timestamp
    ) - datetime.datetime.fromisoformat(previous_message.langgraph_ts)
    return elapsed.total_seconds()


def latency(object: Union[Thread, Turn, Message]) -> float:
    """
    Return the estimated time, in seconds, that it took for the
    Thread/Turn/Message to be generated.

    For Turns and Messages, this is done by comparing the timestamp of the
    Turn/Message, which indicates the output time of that Turn/Message, to
    the timestamp of the closest earlier message in the same thread. For a
    Turn the timestamp used is that of its last message. For example, if a
    Message is generated at 1:27.3 but the previous message was generated at
    1:23.1, the latency was 4.2 seconds. When there is no earlier message
    the latency is 0.0.

    For Threads, the latency is calculated as the time difference, again in
    seconds, between the first and last message.

    Timestamps are read from ``langgraph_ts`` and parsed with
    ``datetime.fromisoformat``.

    Args:
        object (Union[Thread, Turn, Message]): the object to time.

    Returns:
        float: the latency in seconds; 0.0 for a Thread or Turn without
        messages.

    Raises:
        ValueError: If the object is not a Thread, Turn or Message.
    """
    if isinstance(object, Thread):
        # difference between earliest and latest message
        earliest = object.messages.order_by(Message.langgraph_ts).first()
        latest = object.messages.order_by(Message.langgraph_ts.desc()).first()
        if earliest is None or latest is None:
            return 0.0
        span = datetime.datetime.fromisoformat(
            latest.langgraph_ts
        ) - datetime.datetime.fromisoformat(earliest.langgraph_ts)
        return span.total_seconds()

    if isinstance(object, Message):
        return _seconds_since_previous_message(object.thread, object.langgraph_ts)

    if isinstance(object, Turn):
        # Compare the LAST message of this turn with the closest earlier
        # message in the thread.
        last_message_in_turn = (
            Message.select()
            .where(Message.turn == object)
            .order_by(Message.langgraph_ts.desc())
            .first()
        )
        if last_message_in_turn is None:
            return 0.0
        return _seconds_since_previous_message(
            object.thread, last_message_in_turn.langgraph_ts
        )

    raise ValueError(f"Unknown object type {type(object)}.")