myokit 1.36.1__py3-none-any.whl → 1.37.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. myokit/__init__.py +6 -19
  2. myokit/_aux.py +4 -0
  3. myokit/_datablock.py +55 -65
  4. myokit/_datalog.py +42 -7
  5. myokit/_err.py +26 -3
  6. myokit/_expressions.py +241 -127
  7. myokit/_model_api.py +19 -13
  8. myokit/_myokit_version.py +1 -1
  9. myokit/_sim/jacobian.py +3 -3
  10. myokit/_sim/openclsim.py +5 -5
  11. myokit/_sim/rhs.py +1 -1
  12. myokit/formats/__init__.py +4 -9
  13. myokit/formats/ansic/_ewriter.py +4 -20
  14. myokit/formats/axon/_abf.py +11 -4
  15. myokit/formats/diffsl/__init__.py +60 -0
  16. myokit/formats/diffsl/_ewriter.py +145 -0
  17. myokit/formats/diffsl/_exporter.py +435 -0
  18. myokit/formats/heka/_patchmaster.py +345 -115
  19. myokit/formats/opencl/_ewriter.py +3 -42
  20. myokit/formats/opencl/template/minilog.py +1 -1
  21. myokit/formats/sympy/_ereader.py +2 -1
  22. myokit/formats/wcp/_wcp.py +3 -3
  23. myokit/gui/datalog_viewer.py +28 -9
  24. myokit/lib/markov.py +2 -2
  25. myokit/lib/plots.py +4 -4
  26. myokit/tests/data/formats/wcp-file-empty.wcp +0 -0
  27. myokit/tests/data/io/bad1d-2-no-header.zip +0 -0
  28. myokit/tests/data/io/bad1d-3-no-data.zip +0 -0
  29. myokit/tests/data/io/bad1d-4-not-a-zip.zip +1 -105
  30. myokit/tests/data/io/bad1d-5-bad-data-type.zip +0 -0
  31. myokit/tests/data/io/bad1d-6-time-too-short.zip +0 -0
  32. myokit/tests/data/io/bad1d-7-0d-too-short.zip +0 -0
  33. myokit/tests/data/io/bad1d-8-1d-too-short.zip +0 -0
  34. myokit/tests/data/io/bad2d-2-no-header.zip +0 -0
  35. myokit/tests/data/io/bad2d-3-no-data.zip +0 -0
  36. myokit/tests/data/io/bad2d-4-not-a-zip.zip +1 -105
  37. myokit/tests/data/io/bad2d-5-bad-data-type.zip +0 -0
  38. myokit/tests/data/io/bad2d-8-2d-too-short.zip +0 -0
  39. myokit/tests/data/io/block1d.mmt +187 -0
  40. myokit/tests/data/io/datalog-18-duplicate-keys.csv +4 -0
  41. myokit/tests/test_aux.py +4 -0
  42. myokit/tests/test_datablock.py +16 -16
  43. myokit/tests/test_datalog.py +24 -1
  44. myokit/tests/test_expressions.py +532 -251
  45. myokit/tests/test_formats_ansic.py +6 -18
  46. myokit/tests/test_formats_cpp.py +0 -5
  47. myokit/tests/test_formats_cuda.py +7 -15
  48. myokit/tests/test_formats_diffsl.py +728 -0
  49. myokit/tests/test_formats_easyml.py +4 -9
  50. myokit/tests/test_formats_exporters_run.py +3 -0
  51. myokit/tests/test_formats_latex.py +10 -11
  52. myokit/tests/test_formats_matlab.py +0 -8
  53. myokit/tests/test_formats_opencl.py +0 -29
  54. myokit/tests/test_formats_python.py +2 -19
  55. myokit/tests/test_formats_stan.py +0 -13
  56. myokit/tests/test_formats_sympy.py +3 -3
  57. myokit/tests/test_formats_wcp.py +15 -0
  58. myokit/tests/test_model.py +20 -20
  59. myokit/tests/test_parsing.py +19 -0
  60. {myokit-1.36.1.dist-info → myokit-1.37.1.dist-info}/METADATA +1 -1
  61. {myokit-1.36.1.dist-info → myokit-1.37.1.dist-info}/RECORD +65 -58
  62. {myokit-1.36.1.dist-info → myokit-1.37.1.dist-info}/LICENSE.txt +0 -0
  63. {myokit-1.36.1.dist-info → myokit-1.37.1.dist-info}/WHEEL +0 -0
  64. {myokit-1.36.1.dist-info → myokit-1.37.1.dist-info}/entry_points.txt +0 -0
  65. {myokit-1.36.1.dist-info → myokit-1.37.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,728 @@
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Tests the DiffSL module.
4
+ #
5
+ # This file is part of Myokit.
6
+ # See http://myokit.org for copyright, sharing, and licensing details.
7
+ #
8
+ import itertools
9
+ import unittest
10
+
11
+ import myokit
12
+ import myokit.formats
13
+ import myokit.formats.diffsl
14
+ import myokit.tests
15
+ from myokit import (Abs, ACos, And, ASin, ATan, Ceil, Cos, Divide, Equal, Exp,
16
+ Floor, If, Less, LessEqual, Log, Log10, Minus, More,
17
+ MoreEqual, Multiply, Not, NotEqual, Number, Or, Piecewise,
18
+ Plus, Power, PrefixMinus, PrefixPlus, Quotient, Remainder,
19
+ Sin, Sqrt, Tan)
20
+ from myokit.tests import TemporaryDirectory, WarningCollector
21
+
22
+ # Model that requires unit conversion when exported: its variables use
+ # mixed units (V, pF, pS, pA, s, 1/s). The names tt.tTest, tt.t_Test and
+ # tt.t_test appear in units_output as ttTTest1, ttTTest2 and ttTTest3,
+ # which checks that clashing exported names are made unique.
23
+ units_model = """
24
+ [[model]]
25
+ mb.V = -0.08
26
+ hh.x = 0.1
27
+ hh.y = 0.9
28
+ mm.C = 0.9
29
+
30
+ [engine]
31
+ pace = 0 bind pace
32
+ time = 0 [s]
33
+ in [s]
34
+ bind time
35
+
36
+ [ideal]
37
+ Vc = engine.pace * 1
38
+
39
+ [mb]
40
+ dot(V) = (hh.I1 + mm.I2 + tt.I3) / C
41
+ in [V]
42
+ C = 20 [pF]
43
+ in [pF]
44
+
45
+ [hh]
46
+ dot(x) = (inf - x) / tau
47
+ inf = 0.8
48
+ tau = 3 [s]
49
+ in [s]
50
+ dot(y) = alpha * (1 - y) - beta * y
51
+ alpha = 0.1 [1/s]
52
+ in [1/s]
53
+ beta = 0.2 [1/s]
54
+ in [1/s]
55
+ I1 = 3 [pS] * x * y * (mb.V - 0.05 [V])
56
+ in [pA]
57
+
58
+ [mm]
59
+ dot_C = beta * O - alpha * C
60
+ dot(C) = dot_C
61
+ alpha = 0.3 [1/s]
62
+ in [1/s]
63
+ beta = 0.4 [1/s]
64
+ in [1/s]
65
+ O = 1 - C
66
+ I2 = 2 [pS] * O * (mb.V + 0.02 [V])
67
+ in [pA]
68
+
69
+ [tt]
70
+ I3 = 4 [pS] * mm.O * (mb.V + 0.02 [V])
71
+ tTest = 0
72
+ t_Test = 1
73
+ t_test = 2
74
+ """
75
+
+ # Expected DiffSL export of units_model. test_unit_conversion strips the
+ # whole text and then compares it to the exporter output line by line.
+ # NOTE(review): the u_i entry for mbV keeps a Myokit unit annotation
+ # ('[1 (0.001)]'), which looks like a leak from unit conversion rather than
+ # DiffSL syntax -- confirm this is the intended exporter output.
76
+ units_output = """
77
+ /*
78
+ This file was generated by Myokit.
79
+ */
80
+
81
+ /* Input parameters */
82
+ /* E.g. in = [ varZero, varOne, varTwo ] */
83
+ in = [ ]
84
+
85
+ /* Engine: pace */
86
+ /* E.g.
87
+ -80 * (1 - sigmoid((t-100)*5000))
88
+ -120 * (sigmoid((t-100)*5000) - sigmoid((t-200)*5000))
89
+ */
90
+ enginePace { 0.0 } /* engine.pace */
91
+
92
+ /* Constants: hh */
93
+ hhXInf { 0.8 } /* hh.x.inf */
94
+ hhXTau { 3.0 } /* hh.x.tau [s] */
95
+ hhYAlpha { 0.1 } /* hh.y.alpha [S/F] */
96
+ hhYBeta { 0.2 } /* hh.y.beta [S/F] */
97
+
98
+ /* Constants: mm */
99
+ mmAlpha { 0.3 } /* mm.alpha [S/F] */
100
+ mmBeta { 0.4 } /* mm.beta [S/F] */
101
+
102
+ /* Constants: tt */
103
+ ttTTest1 { 0.0 } /* tt.tTest */
104
+ ttTTest2 { 1.0 } /* tt.t_Test */
105
+ ttTTest3 { 2.0 } /* tt.t_test */
106
+
107
+ /* Constants: mb */
108
+ mbC { 20.0 } /* mb.C [pF] */
109
+
110
+ /* Initial conditions */
111
+ u_i {
112
+ mbV = -0.08 * 1000 [1 (0.001)], /* mb.V [mV] */
113
+ hhX = 0.1, /* hh.x */
114
+ hhY = 0.9, /* hh.y */
115
+ mmC = 0.9, /* mm.C */
116
+ }
117
+
118
+ dudt_i {
119
+ diffMbV = 0,
120
+ diffHhX = 0,
121
+ diffHhY = 0,
122
+ diffMmC = 0,
123
+ }
124
+
125
+ /* Variables: hh */
126
+ hhI1 { 3.0 * hhX * hhY * (mbV / 1000.0 - 0.05) * 0.05 } /* hh.I1 [A/F] */
127
+
128
+ /* Variables: mm */
129
+ mmO { 1.0 - mmC } /* mm.O */
130
+ mmI2 { 2.0 * mmO * (mbV / 1000.0 + 0.02) * 0.05 } /* mm.I2 [A/F] */
131
+ mmDotC { mmBeta * mmO - mmAlpha * mmC } /* mm.dot_C */
132
+
133
+ /* Variables: ideal */
134
+ idealVc { enginePace * 1.0 } /* ideal.Vc */
135
+
136
+ /* Variables: tt */
137
+ ttI3 { 4.0 * mmO * (mbV / 1000.0 + 0.02) } /* tt.I3 */
138
+
139
+ /* Solve */
140
+ F_i {
141
+ diffMbV,
142
+ diffHhX,
143
+ diffHhY,
144
+ diffMmC,
145
+ }
146
+
147
+ G_i {
148
+ (hhI1 / 0.05 + mmI2 / 0.05 + ttI3) / mbC * 1000.0 / 1000.0,
149
+ (hhXInf - hhX) / hhXTau / 1000.0,
150
+ (hhYAlpha * (1.0 - hhY) - hhYBeta * hhY) / 1000.0,
151
+ mmDotC / 1000.0,
152
+ }
153
+
154
+ /* Output */
155
+ out_i {
156
+ hhX,
157
+ hhY,
158
+ mbV,
159
+ mmC,
160
+ }
161
+ """
162
+
163
+
class DiffSLExporterTest(unittest.TestCase):
    """Tests for the DiffSL model exporter."""

    def test_diffsl_exporter(self):
        # Exports the example model, then modifies it step by step to check
        # which variations are accepted and which raise errors.
        example = myokit.load_model('example')
        exporter = myokit.formats.diffsl.DiffSLExporter()

        with TemporaryDirectory() as tmp:
            target = tmp.path('diffsl.model')

            # Plain export of the unmodified model
            exporter.model(target, example)

            # Passing a protocol is rejected
            with self.assertRaisesRegex(ValueError, 'input protocol'):
                exporter.model(
                    target, example,
                    protocol='-80 + 120*heaviside(t-10)')

            # Binding another variable (to label 'hello') must not break export
            example.get('membrane.C').set_binding('hello')
            exporter.model(target, example)

            # Export still works when V is no longer a state variable
            voltage = example.get('membrane.V')
            voltage.demote()
            voltage.set_rhs(3)
            exporter.model(target, example)

            # A variable that depends explicitly on time is rejected
            t0 = example.get('membrane').add_variable('t0')
            t0.set_rhs('0 + engine.time')
            with self.assertRaisesRegex(myokit.ExportError, 'time dependence'):
                exporter.model(target, example)

            # A self-referencing right-hand side makes the model invalid
            voltage.set_rhs('2 * V')
            with self.assertRaisesRegex(myokit.ExportError, 'valid model'):
                exporter.model(target, example)

    def test_unit_conversion(self):
        # Exports a model whose variables need unit conversion, then checks
        # the generated text and the warnings emitted for failed conversions.
        model = myokit.parse_model(units_model)
        exporter = myokit.formats.diffsl.DiffSLExporter()

        with TemporaryDirectory() as tmp:
            target = tmp.path('diffsl.model')
            exporter.model(target, model)
            with open(target, 'r') as handle:
                generated = handle.read().strip().splitlines()

        reference = units_output.strip().splitlines()

        # Comparing line by line gives readable failure messages; the length
        # check afterwards catches missing or surplus lines.
        for got, want in zip(generated, reference):
            self.assertEqual(got, want)
        self.assertEqual(len(generated), len(reference))

        # Setting dot(mb.V) directly to the currents (and removing mb.C)
        # makes their conversion fail, which must be reported as a warning.
        model.get('mb.V').set_rhs('hh.I1 + mm.I2')
        model.get('mb').remove_variable(model.get('mb.C'))
        with TemporaryDirectory() as tmp:
            target = tmp.path('diffsl.model')
            with WarningCollector() as collector:
                exporter.model(target, model)
        self.assertIn('Unable to convert hh.I1', collector.text())
        self.assertIn('Unable to convert mm.I2', collector.text())

        # Time in an incompatible unit cannot be converted either
        model.get('engine.time').set_unit(myokit.units.cm)
        with TemporaryDirectory() as tmp:
            target = tmp.path('diffsl.model')
            with WarningCollector() as collector:
                exporter.model(target, model)
        self.assertIn('Unable to convert engine.time', collector.text())

    def test_diffsl_exporter_fetching(self):
        # The exporter can be obtained by name through myokit.formats
        self.assertIsInstance(
            myokit.formats.exporter('diffsl'),
            myokit.formats.diffsl.DiffSLExporter,
        )

    def test_capability_reporting(self):
        # The exporter must report that it can export models
        exporter = myokit.formats.diffsl.DiffSLExporter()
        self.assertTrue(exporter.supports_model())
251
+
252
+
class DiffSLExpressionWriterTest(myokit.tests.ExpressionWriterTestCase):
    """Test conversion to DiffSL syntax."""

    _name = 'diffsl'
    _target = myokit.formats.diffsl.DiffSLExpressionWriter

    def test_number(self):
        self.eq(Number(1), '1.0')
        self.eq(Number(-2), '-2.0')
        # Units are not written
        self.eq(Number(13, 'mV'), '13.0')

    def test_name(self):
        # Inherited from CBasedExpressionWriter
        self.eq(self.a, 'a')
        w = self._target()
        w.set_lhs_function(lambda v: v.var().qname().upper())
        self.assertEqual(w.ex(self.a), 'COMP.A')

    def test_derivative(self):
        # Inherited from CBasedExpressionWriter
        self.eq(myokit.Derivative(self.a), 'dot(a)')

    def test_partial_derivative(self):
        e = myokit.PartialDerivative(self.a, self.b)
        self.assertRaisesRegex(NotImplementedError, 'Partial', self.w.ex, e)

    def test_initial_value(self):
        e = myokit.InitialValue(self.a)
        self.assertRaisesRegex(NotImplementedError, 'Initial', self.w.ex, e)

    def test_prefix_plus_minus(self):
        # Inherited from CBasedExpressionWriter
        p = Number(11, 'kV')
        a, b, c = self.abc
        self.eq(PrefixPlus(p), '+11.0')
        self.eq(PrefixPlus(PrefixPlus(PrefixPlus(p))), '+(+(+11.0))')
        self.eq(Divide(PrefixPlus(Plus(a, b)), c), '+(a + b) / c')
        self.eq(PrefixMinus(p), '-11.0')
        self.eq(PrefixMinus(PrefixMinus(p)), '-(-11.0)')
        self.eq(PrefixMinus(Number(-1)), '-(-1.0)')
        self.eq(PrefixMinus(Minus(a, b)), '-(a - b)')
        self.eq(Multiply(PrefixMinus(Plus(b, a)), c), '-(b + a) * c')
        self.eq(PrefixMinus(Divide(b, a)), '-(b / a)')

    def test_plus_minus(self):
        a, b, c = self.abc
        self.eq(Plus(a, b), 'a + b')
        self.eq(Plus(Plus(a, b), c), 'a + b + c')
        self.eq(Plus(a, Plus(b, c)), 'a + (b + c)')

        self.eq(Minus(a, b), 'a - b')
        self.eq(Minus(Minus(a, b), c), 'a - b - c')
        self.eq(Minus(a, Minus(b, c)), 'a - (b - c)')

        # Mixed plus and minus
        self.eq(Plus(Minus(a, b), c), 'a - b + c')
        self.eq(Minus(a, Plus(b, c)), 'a - (b + c)')
        self.eq(Minus(Plus(a, b), c), 'a + b - c')

        # DiffSL has no expm1, so exp(x) - 1 is written out in full
        self.eq(Minus(Exp(Number(2)), Number(1)), 'exp(2.0) - 1.0')
        self.eq(Minus(Number(1), Exp(Number(3))), '1.0 - exp(3.0)')

    def test_multiply_divide(self):
        # Inherited from CBasedExpressionWriter
        a, b, c = self.abc
        self.eq(Multiply(a, b), 'a * b')
        self.eq(Multiply(Multiply(a, b), c), 'a * b * c')
        self.eq(Multiply(a, Multiply(b, c)), 'a * (b * c)')
        self.eq(Divide(a, b), 'a / b')
        self.eq(Divide(Divide(a, b), c), 'a / b / c')
        self.eq(Divide(a, Divide(b, c)), 'a / (b / c)')

    def test_quotient(self):
        # Inherited from CBasedExpressionWriter. Quotients are written using
        # floor(), which DiffSL does not support (see
        # test_unsupported_functions), so the warnings are collected here.
        a, b, c = self.abc
        with WarningCollector():
            self.eq(Quotient(a, b), 'floor(a / b)')
            self.eq(Quotient(Plus(a, c), b), 'floor((a + c) / b)')
            self.eq(Quotient(Divide(a, c), b), 'floor(a / c / b)')
            self.eq(Quotient(a, Divide(b, c)), 'floor(a / (b / c))')
            self.eq(Multiply(Quotient(a, b), c), 'floor(a / b) * c')
            self.eq(Multiply(c, Quotient(a, b)), 'c * (floor(a / b))')

    def test_remainder(self):
        # Inherited from CBasedExpressionWriter. Like quotients, remainders
        # use floor(), so the resulting warnings are collected.
        a, b, c = self.abc
        with WarningCollector():
            self.eq(Remainder(a, b), '(a - b * floor(a / b))')
            self.eq(
                Remainder(Plus(a, c), b), '(a + c - b * floor((a + c) / b))'
            )
            self.eq(Multiply(Remainder(a, b), c), '(a - b * floor(a / b)) * c')
            self.eq(Divide(c, Remainder(b, a)), 'c / ((b - a * floor(b / a)))')

    def test_power(self):
        # Inherited from CBasedExpressionWriter
        a, b, c = self.abc
        self.eq(Power(a, b), 'pow(a, b)')
        self.eq(Power(Power(a, b), c), 'pow(pow(a, b), c)')
        self.eq(Power(a, Power(b, c)), 'pow(a, pow(b, c))')

    def test_log(self):
        # Inherited from CBasedExpressionWriter
        a, b = self.ab
        self.eq(Log(a), 'log(a)')
        self.eq(Log10(a), '(log(a) / log(10.0))')
        self.eq(Log(a, b), '(log(a) / log(b))')

    def test_supported_functions(self):
        a = self.a

        self.eq(Abs(a), 'abs(a)')
        self.eq(Cos(a), 'cos(a)')
        self.eq(Exp(a), 'exp(a)')
        self.eq(Log(a), 'log(a)')
        self.eq(Sin(a), 'sin(a)')
        self.eq(Sqrt(a), 'sqrt(a)')
        self.eq(Tan(a), 'tan(a)')

    def test_unsupported_functions(self):
        # These functions are still written out, but each must also trigger
        # an 'Unsupported' warning.
        a = self.a
        cases = (
            (ACos(a), 'acos(a)'),
            (ASin(a), 'asin(a)'),
            (ATan(a), 'atan(a)'),
            (Ceil(a), 'ceil(a)'),
            (Floor(a), 'floor(a)'),
        )
        for expression, expected in cases:
            with self.subTest(expected=expected):
                with WarningCollector() as wc:
                    self.eq(expression, expected)
                self.assertIn('Unsupported', wc.text())

    def test_conditional_operators(self):
        # Conditions are written as (products of) heaviside() step functions
        a, b, c, d = self.abcd

        self.eq(Equal(a, b), 'heaviside(a - b) * heaviside(b - a)')

        self.eq(Less(a, b), '(1 - heaviside(a - b))')

        self.eq(LessEqual(a, b), 'heaviside(b - a)')

        self.eq(More(a, b), '(1 - heaviside(b - a))')

        self.eq(MoreEqual(a, b), 'heaviside(a - b)')

        self.eq(NotEqual(a, b), '(1 - heaviside(a - b) * heaviside(b - a))')

        self.eq(
            Not(NotEqual(a, b)),
            '(1 - (1 - heaviside(a - b) * heaviside(b - a)))',
        )

        self.eq(
            Not(Not(Equal(a, b))),
            '(1 - (1 - heaviside(a - b) * heaviside(b - a)))',
        )

        self.eq(
            And(Equal(a, b), NotEqual(c, d)),
            'heaviside(a - b) * heaviside(b - a)'
            ' * (1 - heaviside(c - d) * heaviside(d - c))',
        )

        self.eq(
            Or(More(d, c), MoreEqual(b, a)),
            '(1 - (1 - (1 - heaviside(c - d))) * (1 - heaviside(b - a)))',
        )

        self.eq(
            Or(Less(d, c), LessEqual(b, a)),
            '(1 - (1 - (1 - heaviside(d - c))) * (1 - heaviside(a - b)))',
        )

        self.eq(
            Not(Or(Equal(Number(1), Number(2)), Equal(Number(3), Number(4)))),
            '(1 - (1 - (1 - heaviside(1.0 - 2.0) * heaviside(2.0 - 1.0)) '
            '* (1 - heaviside(3.0 - 4.0) * heaviside(4.0 - 3.0))))',
        )

        self.eq(
            Not(Less(Number(1), Number(2))), '(1 - (1 - heaviside(1.0 - 2.0)))'
        )

    def test_if_expressions(self):
        # if(cond, x, y) is written as cond * x + (1 - cond) * y
        a, b, c, d = self.abcd

        self.eq(
            If(Equal(a, b), c, d),
            '(heaviside(a - b) * heaviside(b - a) * c'
            ' + (1 - heaviside(a - b) * heaviside(b - a)) * d)',
        )

        self.eq(
            If(Equal(a, b), c, Number(0)),
            '(heaviside(a - b) * heaviside(b - a) * c'
            ' + (1 - heaviside(a - b) * heaviside(b - a)) * 0.0)',
        )

        self.eq(
            If(Equal(a, b), Number(0), d),
            '(heaviside(a - b) * heaviside(b - a) * 0.0'
            ' + (1 - heaviside(a - b) * heaviside(b - a)) * d)',
        )

        self.eq(
            If(Equal(a, b), c, Number(1)),
            '(heaviside(a - b) * heaviside(b - a) * c'
            ' + (1 - heaviside(a - b) * heaviside(b - a)) * 1.0)',
        )

        self.eq(
            If(Equal(a, b), Number(1), d),
            '(heaviside(a - b) * heaviside(b - a) * 1.0'
            ' + (1 - heaviside(a - b) * heaviside(b - a)) * d)',
        )

        self.eq(
            If(NotEqual(a, b), c, d),
            '((1 - heaviside(a - b) * heaviside(b - a)) * c'
            ' + (1 - (1 - heaviside(a - b) * heaviside(b - a))) * d)',
        )

        self.eq(
            If(More(a, b), c, d),
            '((1 - heaviside(b - a)) * c + (1 - (1 - heaviside(b - a))) * d)',
        )

        self.eq(
            If(MoreEqual(a, b), c, d),
            '(heaviside(a - b) * c + (1 - heaviside(a - b)) * d)',
        )

        self.eq(
            If(Less(a, b), c, d),
            '((1 - heaviside(a - b)) * c + (1 - (1 - heaviside(a - b))) * d)',
        )

        self.eq(
            If(LessEqual(a, b), c, d),
            '(heaviside(b - a) * c + (1 - heaviside(b - a)) * d)',
        )

    def test_piecewise_expressions(self):
        # A two-piece Piecewise must be written exactly like the matching If
        a, b, c, d = self.abcd

        for condition in (Equal(a, b), NotEqual(a, b), More(a, b),
                          MoreEqual(a, b), Less(a, b), LessEqual(a, b)):
            self.eq(Piecewise(condition, c, d), self.w.ex(If(condition, c, d)))

        # Further pieces are written as nested If expressions
        self.eq(
            Piecewise(Equal(a, b), c, Equal(a, d), Number(3), Number(4)),
            self.w.ex(
                If(Equal(a, b), c, If(Equal(a, d), Number(3), Number(4)))
            ),
        )

        self.eq(
            Piecewise(Less(a, b), Number(0), Less(c, d), Number(0), Number(5)),
            '((1 - heaviside(a - b)) * 0.0 '
            '+ (1 - (1 - heaviside(a - b))) * ((1 - heaviside(c - d)) * 0.0 '
            '+ (1 - (1 - heaviside(c - d))) * 5.0))',
        )

    def test_heaviside_numerical(self):
        """Test generated heaviside expressions with numerical values"""

        def heaviside(x):
            # Step function with heaviside(0) == 1: the expected results below
            # (e.g. a == b giving 1) are only reproduced with this convention.
            return 1 if x >= 0 else 0

        # Generated code is evaluated with an explicit namespace, so it can
        # only see ``heaviside`` (plus builtins), not this method's locals.
        namespace = {'heaviside': heaviside}

        def check(expression, expected):
            # Write ``expression`` as DiffSL, evaluate it as Python, compare.
            self.assertEqual(eval(self.w.ex(expression), namespace), expected)

        values = [-10e9, -1, -1e-9, 0, 1e-9, 1, 10e9]

        # Expressions in a and b only: check each (a, b) pair once, instead
        # of re-evaluating them for every (c, d) combination as well.
        for a, b in itertools.product(values, repeat=2):
            na, nb = Number(a), Number(b)
            check(Equal(na, nb), int(a == b))
            check(Less(na, nb), int(a < b))
            check(LessEqual(na, nb), int(a <= b))
            check(More(na, nb), int(a > b))
            check(MoreEqual(na, nb), int(a >= b))
            check(NotEqual(na, nb), int(a != b))
            check(Not(NotEqual(na, nb)), int(not (a != b)))
            check(Not(Not(Equal(na, nb))), int(not (not (a == b))))
            check(Not(Less(na, nb)), int(not (a < b)))

        # Expressions that need all four values
        for a, b, c, d in itertools.product(values, repeat=4):
            na, nb, nc, nd = Number(a), Number(b), Number(c), Number(d)

            # (a == b) and (c != d)
            check(And(Equal(na, nb), NotEqual(nc, nd)),
                  int((a == b) and (c != d)))
            # (d > c) or (b >= a)
            check(Or(More(nd, nc), MoreEqual(nb, na)),
                  int((d > c) or (b >= a)))
            # (d < c) or (b <= a)
            check(Or(Less(nd, nc), LessEqual(nb, na)),
                  int((d < c) or (b <= a)))
            # (a == b) or (c == d)
            check(Or(Equal(na, nb), Equal(nc, nd)),
                  int((a == b) or (c == d)))

            # if(cond, c, d) and piecewise(cond, c, d) for each comparison
            for condition, holds in ((More(na, nb), a > b),
                                     (MoreEqual(na, nb), a >= b),
                                     (Less(na, nb), a < b),
                                     (LessEqual(na, nb), a <= b)):
                expected = c if holds else d
                check(If(condition, nc, nd), expected)
                check(Piecewise(condition, nc, nd), expected)

            # piecewise(a == b, c, a == d, 3, 4)
            check(
                Piecewise(
                    Equal(na, nb), nc, Equal(na, nd), Number(3), Number(4)),
                c if (a == b) else (3 if (a == d) else 4),
            )

            # piecewise(a < b, 0, c < d, 0, 5)
            check(
                Piecewise(
                    Less(na, nb), Number(0), Less(nc, nd), Number(0),
                    Number(5)),
                0 if (a < b) else (0 if (c < d) else 5),
            )
725
+
726
+
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script
    unittest.main()