lionagi 0.0.305__py3-none-any.whl → 0.0.307__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (84) hide show
  1. lionagi/__init__.py +2 -5
  2. lionagi/core/__init__.py +7 -4
  3. lionagi/core/agent/__init__.py +3 -0
  4. lionagi/core/agent/base_agent.py +46 -0
  5. lionagi/core/branch/__init__.py +4 -0
  6. lionagi/core/branch/base/__init__.py +0 -0
  7. lionagi/core/branch/base_branch.py +100 -78
  8. lionagi/core/branch/branch.py +22 -34
  9. lionagi/core/branch/branch_flow_mixin.py +3 -7
  10. lionagi/core/branch/executable_branch.py +192 -0
  11. lionagi/core/branch/util.py +77 -162
  12. lionagi/core/direct/__init__.py +13 -0
  13. lionagi/core/direct/parallel_predict.py +127 -0
  14. lionagi/core/direct/parallel_react.py +0 -0
  15. lionagi/core/direct/parallel_score.py +0 -0
  16. lionagi/core/direct/parallel_select.py +0 -0
  17. lionagi/core/direct/parallel_sentiment.py +0 -0
  18. lionagi/core/direct/predict.py +174 -0
  19. lionagi/core/direct/react.py +33 -0
  20. lionagi/core/direct/score.py +163 -0
  21. lionagi/core/direct/select.py +144 -0
  22. lionagi/core/direct/sentiment.py +51 -0
  23. lionagi/core/direct/utils.py +83 -0
  24. lionagi/core/flow/__init__.py +0 -3
  25. lionagi/core/flow/monoflow/{mono_react.py → ReAct.py} +52 -9
  26. lionagi/core/flow/monoflow/__init__.py +9 -0
  27. lionagi/core/flow/monoflow/{mono_chat.py → chat.py} +11 -11
  28. lionagi/core/flow/monoflow/{mono_chat_mixin.py → chat_mixin.py} +33 -27
  29. lionagi/core/flow/monoflow/{mono_followup.py → followup.py} +7 -6
  30. lionagi/core/flow/polyflow/__init__.py +1 -0
  31. lionagi/core/flow/polyflow/{polychat.py → chat.py} +15 -3
  32. lionagi/core/mail/__init__.py +8 -0
  33. lionagi/core/mail/mail_manager.py +88 -40
  34. lionagi/core/mail/schema.py +32 -6
  35. lionagi/core/messages/__init__.py +3 -0
  36. lionagi/core/messages/schema.py +56 -25
  37. lionagi/core/prompt/__init__.py +0 -0
  38. lionagi/core/prompt/prompt_template.py +0 -0
  39. lionagi/core/schema/__init__.py +7 -5
  40. lionagi/core/schema/action_node.py +29 -0
  41. lionagi/core/schema/base_mixin.py +56 -59
  42. lionagi/core/schema/base_node.py +35 -38
  43. lionagi/core/schema/condition.py +24 -0
  44. lionagi/core/schema/data_logger.py +98 -98
  45. lionagi/core/schema/data_node.py +19 -19
  46. lionagi/core/schema/prompt_template.py +0 -0
  47. lionagi/core/schema/structure.py +293 -190
  48. lionagi/core/session/__init__.py +1 -3
  49. lionagi/core/session/session.py +196 -214
  50. lionagi/core/tool/tool_manager.py +95 -103
  51. lionagi/integrations/__init__.py +1 -3
  52. lionagi/integrations/bridge/langchain_/documents.py +17 -18
  53. lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
  54. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
  55. lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
  56. lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
  57. lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
  58. lionagi/integrations/config/openrouter_configs.py +0 -1
  59. lionagi/integrations/provider/oai.py +26 -26
  60. lionagi/integrations/provider/services.py +38 -38
  61. lionagi/libs/__init__.py +34 -1
  62. lionagi/libs/ln_api.py +211 -221
  63. lionagi/libs/ln_async.py +53 -60
  64. lionagi/libs/ln_convert.py +118 -120
  65. lionagi/libs/ln_dataframe.py +32 -33
  66. lionagi/libs/ln_func_call.py +334 -342
  67. lionagi/libs/ln_nested.py +99 -107
  68. lionagi/libs/ln_parse.py +175 -158
  69. lionagi/libs/sys_util.py +52 -52
  70. lionagi/tests/test_core/test_base_branch.py +427 -427
  71. lionagi/tests/test_core/test_branch.py +292 -292
  72. lionagi/tests/test_core/test_mail_manager.py +57 -57
  73. lionagi/tests/test_core/test_session.py +254 -266
  74. lionagi/tests/test_core/test_session_base_util.py +299 -300
  75. lionagi/tests/test_core/test_tool_manager.py +70 -74
  76. lionagi/tests/test_libs/test_nested.py +2 -7
  77. lionagi/tests/test_libs/test_parse.py +2 -2
  78. lionagi/version.py +1 -1
  79. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/METADATA +4 -2
  80. lionagi-0.0.307.dist-info/RECORD +115 -0
  81. lionagi-0.0.305.dist-info/RECORD +0 -94
  82. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/LICENSE +0 -0
  83. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/WHEEL +0 -0
  84. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/top_level.txt +0 -0
lionagi/libs/ln_parse.py CHANGED
@@ -1,11 +1,11 @@
1
1
  import re
2
2
  import inspect
3
+ import itertools
3
4
  from collections.abc import Callable
4
5
  from typing import Any
5
6
  import numpy as np
6
7
  import lionagi.libs.ln_convert as convert
7
8
 
8
-
9
9
  md_json_char_map = {"\n": "\\n", "\r": "\\r", "\t": "\\t", '"': '\\"'}
10
10
 
11
11
 
@@ -20,29 +20,35 @@ class ParseUtil:
20
20
  the string by appending necessary closing characters before retrying.
21
21
 
22
22
  Args:
23
- s (str): The JSON string to parse.
24
- strict (bool, optional): If True, enforces strict JSON syntax. Defaults to False.
23
+ s (str): The JSON string to parse.
24
+ strict (bool, optional): If True, enforces strict JSON syntax. Defaults to False.
25
25
 
26
26
  Returns:
27
- The parsed JSON object, typically a dictionary or list.
27
+ The parsed JSON object, typically a dictionary or list.
28
28
 
29
29
  Raises:
30
- ValueError: If parsing fails even after attempting to correct the string.
30
+ ValueError: If parsing fails even after attempting to correct the string.
31
31
 
32
32
  Example:
33
- >>> fuzzy_parse_json('{"name": "John", "age": 30, "city": "New York"')
34
- {'name': 'John', 'age': 30, 'city': 'New York'}
33
+ >>> fuzzy_parse_json('{"name": "John", "age": 30, "city": "New York"')
34
+ {'name': 'John', 'age': 30, 'city': 'New York'}
35
35
  """
36
36
  try:
37
37
  return convert.to_dict(str_to_parse, strict=strict)
38
- except:
38
+ except Exception:
39
39
  fixed_s = ParseUtil.fix_json_string(str_to_parse)
40
40
  try:
41
41
  return convert.to_dict(fixed_s, strict=strict)
42
- except Exception as e:
43
- raise ValueError(
44
- f"Failed to parse JSON even after fixing attempts: {e}"
45
- )
42
+
43
+ except Exception:
44
+ try:
45
+ fixed_s = fixed_s.replace("'", '"')
46
+ return convert.to_dict(fixed_s, strict=strict)
47
+
48
+ except Exception as e:
49
+ raise ValueError(
50
+ f"Failed to parse JSON even after fixing attempts: {e}"
51
+ ) from e
46
52
 
47
53
  @staticmethod
48
54
  def fix_json_string(str_to_parse: str) -> str:
@@ -70,17 +76,17 @@ class ParseUtil:
70
76
  a default mapping is used.
71
77
 
72
78
  Args:
73
- value: The string to be escaped.
74
- char_map: An optional dictionary mapping characters to their escaped versions.
75
- If not provided, a default mapping that escapes newlines, carriage returns,
76
- tabs, and double quotes is used.
79
+ value: The string to be escaped.
80
+ char_map: An optional dictionary mapping characters to their escaped versions.
81
+ If not provided, a default mapping that escapes newlines, carriage returns,
82
+ tabs, and double quotes is used.
77
83
 
78
84
  Returns:
79
- The escaped JSON string.
85
+ The escaped JSON string.
80
86
 
81
87
  Examples:
82
- >>> escape_chars_in_json('Line 1\nLine 2')
83
- 'Line 1\\nLine 2'
88
+ >>> escape_chars_in_json('Line 1\nLine 2')
89
+ 'Line 1\\nLine 2'
84
90
  """
85
91
 
86
92
  def replacement(match):
@@ -108,22 +114,22 @@ class ParseUtil:
108
114
  filtered by language. If a code block is found, it is parsed using the provided parser function.
109
115
 
110
116
  Args:
111
- str_to_parse: The Markdown content to search.
112
- language: An optional language specifier for the code block. If provided,
113
- only code blocks of this language are considered.
114
- regex_pattern: An optional regular expression pattern to use for finding the code block.
115
- If provided, it overrides the language parameter.
116
- parser: A function to parse the extracted code block string.
117
+ str_to_parse: The Markdown content to search.
118
+ language: An optional language specifier for the code block. If provided,
119
+ only code blocks of this language are considered.
120
+ regex_pattern: An optional regular expression pattern to use for finding the code block.
121
+ If provided, it overrides the language parameter.
122
+ parser: A function to parse the extracted code block string.
117
123
 
118
124
  Returns:
119
- The result of parsing the code block with the provided parser function.
125
+ The result of parsing the code block with the provided parser function.
120
126
 
121
127
  Raises:
122
- ValueError: If no code block is found in the Markdown content.
128
+ ValueError: If no code block is found in the Markdown content.
123
129
 
124
130
  Examples:
125
- >>> extract_code_block('```python\\nprint("Hello, world!")\\n```', language='python', parser=lambda x: x)
126
- 'print("Hello, world!")'
131
+ >>> extract_code_block('```python\\nprint("Hello, world!")\\n```', language='python', parser=lambda x: x)
132
+ 'print("Hello, world!")'
127
133
  """
128
134
 
129
135
  if language:
@@ -134,7 +140,7 @@ class ParseUtil:
134
140
  match = re.search(regex_pattern, str_to_parse, re.DOTALL)
135
141
  code_str = ""
136
142
  if match:
137
- code_str = match.group(1).strip()
143
+ code_str = match[1].strip()
138
144
  else:
139
145
  raise ValueError(
140
146
  f"No {language or 'specified'} code block found in the Markdown content."
@@ -156,29 +162,28 @@ class ParseUtil:
156
162
  Markdown string. It then optionally verifies that the parsed JSON object contains all expected keys.
157
163
 
158
164
  Args:
159
- str_to_parse: The Markdown content to parse.
160
- expected_keys: An optional list of keys expected to be present in the parsed JSON object.
161
- parser: An optional function to parse the extracted code block. If not provided,
162
- `fuzzy_parse_json` is used with default settings.
165
+ str_to_parse: The Markdown content to parse.
166
+ expected_keys: An optional list of keys expected to be present in the parsed JSON object.
167
+ parser: An optional function to parse the extracted code block. If not provided,
168
+ `fuzzy_parse_json` is used with default settings.
163
169
 
164
170
  Returns:
165
- The parsed JSON object from the Markdown content.
171
+ The parsed JSON object from the Markdown content.
166
172
 
167
173
  Raises:
168
- ValueError: If the JSON code block is missing, or if any of the expected keys are missing
169
- from the parsed JSON object.
174
+ ValueError: If the JSON code block is missing, or if any of the expected keys are missing
175
+ from the parsed JSON object.
170
176
 
171
177
  Examples:
172
- >>> md_to_json('```json\\n{"key": "value"}\\n```', expected_keys=['key'])
173
- {'key': 'value'}
178
+ >>> md_to_json('```json\\n{"key": "value"}\\n```', expected_keys=['key'])
179
+ {'key': 'value'}
174
180
  """
175
181
  json_obj = ParseUtil.extract_code_block(
176
182
  str_to_parse, language="json", parser=parser or ParseUtil.fuzzy_parse_json
177
183
  )
178
184
 
179
185
  if expected_keys:
180
- missing_keys = [key for key in expected_keys if key not in json_obj]
181
- if missing_keys:
186
+ if missing_keys := [key for key in expected_keys if key not in json_obj]:
182
187
  raise ValueError(
183
188
  f"Missing expected keys in JSON object: {', '.join(missing_keys)}"
184
189
  )
@@ -192,26 +197,26 @@ class ParseUtil:
192
197
  docstring following the Google style format.
193
198
 
194
199
  Args:
195
- func (Callable): The function from which to extract docstring details.
200
+ func (Callable): The function from which to extract docstring details.
196
201
 
197
202
  Returns:
198
- Tuple[str, Dict[str, str]]: A tuple containing the function description
199
- and a dictionary with parameter names as keys and their descriptions as values.
203
+ Tuple[str, Dict[str, str]]: A tuple containing the function description
204
+ and a dictionary with parameter names as keys and their descriptions as values.
200
205
 
201
206
  Examples:
202
- >>> def example_function(param1: int, param2: str):
203
- ... '''Example function.
204
- ...
205
- ... Args:
206
- ... param1 (int): The first parameter.
207
- ... param2 (str): The second parameter.
208
- ... '''
209
- ... pass
210
- >>> description, params = _extract_docstring_details_google(example_function)
211
- >>> description
212
- 'Example function.'
213
- >>> params == {'param1': 'The first parameter.', 'param2': 'The second parameter.'}
214
- True
207
+ >>> def example_function(param1: int, param2: str):
208
+ ... '''Example function.
209
+ ...
210
+ ... Args:
211
+ ... param1 (int): The first parameter.
212
+ ... param2 (str): The second parameter.
213
+ ... '''
214
+ ... pass
215
+ >>> description, params = _extract_docstring_details_google(example_function)
216
+ >>> description
217
+ 'Example function.'
218
+ >>> params == {'param1': 'The first parameter.', 'param2': 'The second parameter.'}
219
+ True
215
220
  """
216
221
  docstring = inspect.getdoc(func)
217
222
  if not docstring:
@@ -219,19 +224,21 @@ class ParseUtil:
219
224
  lines = docstring.split("\n")
220
225
  func_description = lines[0].strip()
221
226
 
222
- param_start_pos = 0
223
227
  lines_len = len(lines)
224
228
 
225
229
  params_description = {}
226
- for i in range(1, lines_len):
227
- if (
228
- lines[i].startswith("Args")
229
- or lines[i].startswith("Arguments")
230
- or lines[i].startswith("Parameters")
231
- ):
232
- param_start_pos = i + 1
233
- break
234
-
230
+ param_start_pos = next(
231
+ (
232
+ i + 1
233
+ for i in range(1, lines_len)
234
+ if (
235
+ lines[i].startswith("Args")
236
+ or lines[i].startswith("Arguments")
237
+ or lines[i].startswith("Parameters")
238
+ )
239
+ ),
240
+ 0,
241
+ )
235
242
  current_param = None
236
243
  for i in range(param_start_pos, lines_len):
237
244
  if lines[i] == "":
@@ -239,7 +246,7 @@ class ParseUtil:
239
246
  elif lines[i].startswith(" "):
240
247
  param_desc = lines[i].split(":", 1)
241
248
  if len(param_desc) == 1:
242
- params_description[current_param] += " " + param_desc[0].strip()
249
+ params_description[current_param] += f" {param_desc[0].strip()}"
243
250
  continue
244
251
  param, desc = param_desc
245
252
  param = param.split("(")[0].strip()
@@ -256,27 +263,27 @@ class ParseUtil:
256
263
  docstring following the reStructuredText (reST) style format.
257
264
 
258
265
  Args:
259
- func (Callable): The function from which to extract docstring details.
266
+ func (Callable): The function from which to extract docstring details.
260
267
 
261
268
  Returns:
262
- Tuple[str, Dict[str, str]]: A tuple containing the function description
263
- and a dictionary with parameter names as keys and their descriptions as values.
269
+ Tuple[str, Dict[str, str]]: A tuple containing the function description
270
+ and a dictionary with parameter names as keys and their descriptions as values.
264
271
 
265
272
  Examples:
266
- >>> def example_function(param1: int, param2: str):
267
- ... '''Example function.
268
- ...
269
- ... :param param1: The first parameter.
270
- ... :type param1: int
271
- ... :param param2: The second parameter.
272
- ... :type param2: str
273
- ... '''
274
- ... pass
275
- >>> description, params = _extract_docstring_details_rest(example_function)
276
- >>> description
277
- 'Example function.'
278
- >>> params == {'param1': 'The first parameter.', 'param2': 'The second parameter.'}
279
- True
273
+ >>> def example_function(param1: int, param2: str):
274
+ ... '''Example function.
275
+ ...
276
+ ... :param param1: The first parameter.
277
+ ... :type param1: int
278
+ ... :param param2: The second parameter.
279
+ ... :type param2: str
280
+ ... '''
281
+ ... pass
282
+ >>> description, params = _extract_docstring_details_rest(example_function)
283
+ >>> description
284
+ 'Example function.'
285
+ >>> params == {'param1': 'The first parameter.', 'param2': 'The second parameter.'}
286
+ True
280
287
  """
281
288
  docstring = inspect.getdoc(func)
282
289
  if not docstring:
@@ -295,7 +302,7 @@ class ParseUtil:
295
302
  params_description[param] = desc.strip()
296
303
  current_param = param
297
304
  elif line.startswith(" "):
298
- params_description[current_param] += " " + line
305
+ params_description[current_param] += f" {line}"
299
306
 
300
307
  return func_description, params_description
301
308
 
@@ -307,30 +314,30 @@ class ParseUtil:
307
314
  (reST) style format.
308
315
 
309
316
  Args:
310
- func (Callable): The function from which to extract docstring details.
311
- style (str): The style of docstring to parse ('google' or 'reST').
317
+ func (Callable): The function from which to extract docstring details.
318
+ style (str): The style of docstring to parse ('google' or 'reST').
312
319
 
313
320
  Returns:
314
- Tuple[str, Dict[str, str]]: A tuple containing the function description
315
- and a dictionary with parameter names as keys and their descriptions as values.
321
+ Tuple[str, Dict[str, str]]: A tuple containing the function description
322
+ and a dictionary with parameter names as keys and their descriptions as values.
316
323
 
317
324
  Raises:
318
- ValueError: If an unsupported style is provided.
325
+ ValueError: If an unsupported style is provided.
319
326
 
320
327
  Examples:
321
- >>> def example_function(param1: int, param2: str):
322
- ... '''Example function.
323
- ...
324
- ... Args:
325
- ... param1 (int): The first parameter.
326
- ... param2 (str): The second parameter.
327
- ... '''
328
- ... pass
329
- >>> description, params = _extract_docstring_details(example_function, style='google')
330
- >>> description
331
- 'Example function.'
332
- >>> params == {'param1': 'The first parameter.', 'param2': 'The second parameter.'}
333
- True
328
+ >>> def example_function(param1: int, param2: str):
329
+ ... '''Example function.
330
+ ...
331
+ ... Args:
332
+ ... param1 (int): The first parameter.
333
+ ... param2 (str): The second parameter.
334
+ ... '''
335
+ ... pass
336
+ >>> description, params = _extract_docstring_details(example_function, style='google')
337
+ >>> description
338
+ 'Example function.'
339
+ >>> params == {'param1': 'The first parameter.', 'param2': 'The second parameter.'}
340
+ True
334
341
  """
335
342
  if style == "google":
336
343
  func_description, params_description = (
@@ -352,16 +359,16 @@ class ParseUtil:
352
359
  Converts a Python type to its JSON type equivalent.
353
360
 
354
361
  Args:
355
- py_type (str): The name of the Python type.
362
+ py_type (str): The name of the Python type.
356
363
 
357
364
  Returns:
358
- str: The corresponding JSON type.
365
+ str: The corresponding JSON type.
359
366
 
360
367
  Examples:
361
- >>> _python_to_json_type('str')
362
- 'string'
363
- >>> _python_to_json_type('int')
364
- 'number'
368
+ >>> _python_to_json_type('str')
369
+ 'string'
370
+ >>> _python_to_json_type('int')
371
+ 'number'
365
372
  """
366
373
  type_mapping = {
367
374
  "str": "string",
@@ -381,24 +388,24 @@ class ParseUtil:
381
388
  docstrings. The schema includes the function's name, description, and parameters.
382
389
 
383
390
  Args:
384
- func (Callable): The function to generate a schema for.
385
- style (str): The docstring format ('google' or 'reST').
391
+ func (Callable): The function to generate a schema for.
392
+ style (str): The docstring format ('google' or 'reST').
386
393
 
387
394
  Returns:
388
- Dict[str, Any]: A schema describing the function.
395
+ Dict[str, Any]: A schema describing the function.
389
396
 
390
397
  Examples:
391
- >>> def example_function(param1: int, param2: str) -> bool:
392
- ... '''Example function.
393
- ...
394
- ... Args:
395
- ... param1 (int): The first parameter.
396
- ... param2 (str): The second parameter.
397
- ... '''
398
- ... return True
399
- >>> schema = _func_to_schema(example_function)
400
- >>> schema['function']['name']
401
- 'example_function'
398
+ >>> def example_function(param1: int, param2: str) -> bool:
399
+ ... '''Example function.
400
+ ...
401
+ ... Args:
402
+ ... param1 (int): The first parameter.
403
+ ... param2 (str): The second parameter.
404
+ ... '''
405
+ ... return True
406
+ >>> schema = _func_to_schema(example_function)
407
+ >>> schema['function']['name']
408
+ 'example_function'
402
409
  """
403
410
  # Extracting function name and docstring details
404
411
  func_name = func.__name__
@@ -432,8 +439,7 @@ class ParseUtil:
432
439
  "description": param_description,
433
440
  }
434
441
 
435
- # Constructing the schema
436
- schema = {
442
+ return {
437
443
  "type": "function",
438
444
  "function": {
439
445
  "name": func_name,
@@ -442,8 +448,6 @@ class ParseUtil:
442
448
  },
443
449
  }
444
450
 
445
- return schema
446
-
447
451
 
448
452
  class StringMatch:
449
453
 
@@ -457,16 +461,16 @@ class StringMatch:
457
461
  and 1 is an exact match.
458
462
 
459
463
  Args:
460
- s: The first string to compare.
461
- t: The second string to compare.
464
+ s: The first string to compare.
465
+ t: The second string to compare.
462
466
 
463
467
  Returns:
464
- A float representing the Jaro distance between the two strings, ranging from 0 to 1,
465
- where 1 means the strings are identical.
468
+ A float representing the Jaro distance between the two strings, ranging from 0 to 1,
469
+ where 1 means the strings are identical.
466
470
 
467
471
  Examples:
468
- >>> jaro_distance("martha", "marhta")
469
- 0.9444444444444445
472
+ >>> jaro_distance("martha", "marhta")
473
+ 0.9444444444444445
470
474
  """
471
475
  s_len = len(s)
472
476
  t_len = len(t)
@@ -521,18 +525,18 @@ class StringMatch:
521
525
  person names, and is designed to improve the scoring of strings that have a common prefix.
522
526
 
523
527
  Args:
524
- s: The first string to compare.
525
- t: The second string to compare.
526
- scaling: The scaling factor for how much the score is adjusted upwards for having common prefixes.
527
- The scaling factor should be less than 1, and a typical value is 0.1.
528
+ s: The first string to compare.
529
+ t: The second string to compare.
530
+ scaling: The scaling factor for how much the score is adjusted upwards for having common prefixes.
531
+ The scaling factor should be less than 1, and a typical value is 0.1.
528
532
 
529
533
  Returns:
530
- A float representing the Jaro-Winkler similarity between the two strings, ranging from 0 to 1,
531
- where 1 means the strings are identical.
534
+ A float representing the Jaro-Winkler similarity between the two strings, ranging from 0 to 1,
535
+ where 1 means the strings are identical.
532
536
 
533
537
  Examples:
534
- >>> jaro_winkler_similarity("dixon", "dicksonx")
535
- 0.8133333333333332
538
+ >>> jaro_winkler_similarity("dixon", "dicksonx")
539
+ 0.8133333333333332
536
540
  """
537
541
  jaro_sim = StringMatch.jaro_distance(s, t)
538
542
  prefix_len = 0
@@ -555,15 +559,15 @@ class StringMatch:
555
559
  required to change one word into the other. Each operation has an equal cost.
556
560
 
557
561
  Args:
558
- a: The first string to compare.
559
- b: The second string to compare.
562
+ a: The first string to compare.
563
+ b: The second string to compare.
560
564
 
561
565
  Returns:
562
- An integer representing the Levenshtein distance between the two strings.
566
+ An integer representing the Levenshtein distance between the two strings.
563
567
 
564
568
  Examples:
565
- >>> levenshtein_distance("kitten", "sitting")
566
- 3
569
+ >>> levenshtein_distance("kitten", "sitting")
570
+ 3
567
571
  """
568
572
  m, n = len(a), len(b)
569
573
  # Initialize 2D array (m+1) x (n+1)
@@ -576,17 +580,13 @@ class StringMatch:
576
580
  d[0][j] = j
577
581
 
578
582
  # Compute the distance
579
- for i in range(1, m + 1):
580
- for j in range(1, n + 1):
581
- if a[i - 1] == b[j - 1]:
582
- cost = 0
583
- else:
584
- cost = 1
585
- d[i][j] = min(
586
- d[i - 1][j] + 1, # deletion
587
- d[i][j - 1] + 1, # insertion
588
- d[i - 1][j - 1] + cost,
589
- ) # substitution
583
+ for i, j in itertools.product(range(1, m + 1), range(1, n + 1)):
584
+ cost = 0 if a[i - 1] == b[j - 1] else 1
585
+ d[i][j] = min(
586
+ d[i - 1][j] + 1, # deletion
587
+ d[i][j - 1] + 1, # insertion
588
+ d[i - 1][j - 1] + cost,
589
+ ) # substitution
590
590
  return d[m][n]
591
591
 
592
592
  @staticmethod
@@ -620,3 +620,20 @@ class StringMatch:
620
620
  corrected_out[k] = v
621
621
 
622
622
  return corrected_out
623
+
624
+ @staticmethod
625
+ def choose_most_similar(word, correct_words_list, score_func=None):
626
+
627
+ if score_func is None:
628
+ score_func = StringMatch.jaro_winkler_similarity
629
+
630
+ # Calculate Jaro-Winkler similarity scores for each potential match
631
+ scores = np.array(
632
+ [
633
+ score_func(convert.to_str(word), correct_word)
634
+ for correct_word in correct_words_list
635
+ ]
636
+ )
637
+ # Find the index of the highest score
638
+ max_score_index = np.argmax(scores)
639
+ return correct_words_list[max_score_index]