Skip to content

Commit d1fa6ce

Browse files
authored
Complex cosine distance (#1211)
Switch from using Python's multiply and divide, and abs() of complex numbers to SymPy's Expr operators for more precise real-number results
1 parent f7900f0 commit d1fa6ce

File tree

4 files changed

+27
-16
lines changed

4 files changed

+27
-16
lines changed

mathics/builtin/distance/numeric.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ class CosineDistance(Builtin):
154154
<dd>returns the angular cosine distance between vectors $u$ and $v$.
155155
</dl>
156156
157-
The cosine distance is equivalent to 1 - $u$.Conjugate[$v$] / ('Norm[$u$] Norm[$v$]').
157+
The cosine distance is equivalent to 1 - ($u$.Conjugate[$v$]) / ('Norm[$u$] Norm[$v$]').
158158
159159
>> N[CosineDistance[{7, 9}, {71, 89}]]
160160
= 0.0000759646
@@ -173,6 +173,12 @@ class CosineDistance(Builtin):
173173
Cosine distance includes a dot product scaled by norms:
174174
>> CosineDistance[{a, b, c}, {x, y, z}]
175175
= 1 + (-a Conjugate[x] - b Conjugate[y] - c Conjugate[z]) / (Sqrt[Abs[a] ^ 2 + Abs[b] ^ 2 + Abs[c] ^ 2] Sqrt[Abs[x] ^ 2 + Abs[y] ^ 2 + Abs[z] ^ 2])
176+
177+
A Cosine distance applied to complex numbers, uses 'Abs[]' for 'Norm[]' and complex multiplication for dot product,
178+
1 - $u$ * Conjugate[$v$] / ('Abs[$u$] Abs[$v$]'):
179+
180+
>> CosineDistance[1+2I, 5]
181+
= 1 - (1 / 5 + 2 I / 5) Sqrt[5]
176182
"""
177183

178184
messages = {

mathics/builtin/numeric.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
)
3333
from mathics.core.builtin import Builtin, MPMathFunction, SympyFunction
3434
from mathics.core.convert.sympy import from_sympy
35-
from mathics.core.element import BaseElement
3635
from mathics.core.evaluation import Evaluation
3736
from mathics.core.expression import Expression
3837
from mathics.core.number import MACHINE_EPSILON

mathics/eval/distance/numeric.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from sympy.core.power import Mul, Pow
2+
13
from mathics.core.atoms import Complex, Integer, Integer0, Real
2-
from mathics.core.convert.python import from_python
34
from mathics.core.convert.sympy import from_sympy, to_sympy_matrix
5+
from mathics.eval.arithmetic import eval_Abs
46

57

68
def eval_CosineDistance(u, v):
@@ -14,18 +16,22 @@ def eval_CosineDistance(u, v):
1416
if isinstance(u, (Complex, Integer, Real)) and isinstance(
1517
v, (Complex, Integer, Real)
1618
):
17-
u_val = u.to_python()
18-
v_val = v.to_python()
19-
distance = 1 - u_val * v_val.conjugate() / (abs(u_val) * abs(v_val))
20-
21-
# If the input arguments were Integers, preserve that in the result
22-
if isinstance(u_val, int) and isinstance(v_val, int):
23-
try:
24-
if distance == int(distance):
25-
distance = int(distance)
26-
except Exception:
27-
pass
28-
return from_python(distance)
19+
u_abs = eval_Abs(u)
20+
if u_abs is None:
21+
return
22+
v_abs = eval_Abs(v)
23+
if v_abs is None:
24+
return
25+
26+
# Do the following, but using SymPy expressions:
27+
# distance = 1 - (u * v.conjugate()) / (abs(u) * abs(v))
28+
numerator = Mul(u.to_sympy(), v.to_sympy().conjugate())
29+
divisor_product = Mul(u_abs.to_sympy(), v_abs.to_sympy())
30+
distance = 1 - numerator * Pow(divisor_product, -1)
31+
result = from_sympy(distance)
32+
if (isinstance(u, Real) or isinstance(v, Real)) and isinstance(result, Integer):
33+
result = Real(result.value)
34+
return result
2935

3036
sym_u = to_sympy_matrix(u)
3137
if sym_u is None:

test/builtin/distance/test_numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_cosine_distance():
1515
"CosineDistance of an Integers and a Real is 0.0",
1616
None,
1717
),
18-
("CosineDistance[Complex[1, 0], Complex[0, 2]]", "1. + 1 I", None, None),
18+
("CosineDistance[Complex[1, 0], Complex[0, 2]]", "1 + I", None, None),
1919
(
2020
"CosineDistance[Complex[5.0, 0], Complex[10, 3]]",
2121
"0.0421737 + 0.287348 I",

0 commit comments

Comments
 (0)