using System;
using System.Numerics;
using System.Runtime.CompilerServices;
namespace ZeroLevel.HNSW
{
    /// <summary>
    /// Calculates cosine distance, i.e. <c>1 - cosine similarity</c>, between dense <see cref="float"/> vectors.
    /// </summary>
    /// <remarks>
    /// Why <see cref="float"/> is used as the accumulator type:
    /// <list type="number">
    /// <item><description>
    /// In practice we work with vectors of dimensionality 100 whose components lie in the range [-1; 1].
    /// Underflow is possible, but we assume such cases are rare and that the resulting precision loss is tolerable.
    /// </description></item>
    /// <item><description>
    /// According to the article http://www.ti3.tuhh.de/paper/rump/JeaRu13.pdf the floating point rounding
    /// error of the dot product is bounded by 100 * 2^-24 * sqrt(100) * sqrt(100), which is approximately 0.000596.
    /// We deem this precision satisfactory for our needs.
    /// </description></item>
    /// </list>
    /// </remarks>
    public static class CosineDistance
    {
        /// <summary>
        /// Calculates cosine distance without making any optimizations.
        /// </summary>
        /// <param name="u">Left vector.</param>
        /// <param name="v">Right vector.</param>
        /// <returns>
        /// Cosine distance between <paramref name="u"/> and <paramref name="v"/>;
        /// 1 when either vector has zero norm (the same convention as <see cref="SIMD"/>).
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="u"/> or <paramref name="v"/> is null.</exception>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float NonOptimized(float[] u, float[] v)
        {
            ValidateDimensions(u, v);
            float dot = 0.0f;
            float nru = 0.0f;
            float nrv = 0.0f;
            for (int i = 0; i < u.Length; ++i)
            {
                dot += u[i] * v[i];
                nru += u[i] * u[i];
                nrv += v[i] * v[i];
            }
            var denominator = (float)(Math.Sqrt(nru) * Math.Sqrt(nrv));
            if (denominator == 0f)
            {
                // The angle is undefined for a zero (or empty) vector; report "no similarity" instead of NaN.
                return 1f;
            }
            return 1f - dot / denominator;
        }

        /// <summary>
        /// Calculates cosine distance with the assumption that <paramref name="u"/> and <paramref name="v"/>
        /// are unit vectors (their norms are not computed or checked).
        /// </summary>
        /// <param name="u">Left unit vector.</param>
        /// <param name="v">Right unit vector.</param>
        /// <returns><c>1 - dot(u, v)</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="u"/> or <paramref name="v"/> is null.</exception>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float ForUnits(float[] u, float[] v)
        {
            ValidateDimensions(u, v);
            float dot = 0f;
            for (int i = 0; i < u.Length; ++i)
            {
                dot += u[i] * v[i];
            }
            return 1f - dot;
        }

        /// <summary>
        /// Calculates cosine distance using SIMD instructions.
        /// </summary>
        /// <param name="u">Left vector.</param>
        /// <param name="v">Right vector.</param>
        /// <returns>
        /// Cosine distance between <paramref name="u"/> and <paramref name="v"/>;
        /// 1 when either vector has zero norm.
        /// </returns>
        /// <exception cref="NotSupportedException"><see cref="Vector.IsHardwareAccelerated"/> is false.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="u"/> or <paramref name="v"/> is null.</exception>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float SIMD(float[] u, float[] v)
        {
            if (!Vector.IsHardwareAccelerated)
            {
                throw new NotSupportedException($"SIMD version of {nameof(CosineDistance)} is not supported");
            }
            ValidateDimensions(u, v);

            // Accumulate lane-wise and reduce once after the loop; a horizontal Vector.Dot on every
            // iteration would put a full cross-lane reduction on the critical path.
            var dotAcc = Vector<float>.Zero;
            var normUAcc = Vector<float>.Zero;
            var normVAcc = Vector<float>.Zero;
            int width = Vector<float>.Count;
            int i = 0;
            for (int last = u.Length - width; i <= last; i += width)
            {
                var ui = new Vector<float>(u, i);
                var vi = new Vector<float>(v, i);
                dotAcc += ui * vi;
                normUAcc += ui * ui;
                normVAcc += vi * vi;
            }

            // Dot with a vector of ones is the horizontal sum of the lanes.
            float dot = Vector.Dot(dotAcc, Vector<float>.One);
            float normU = Vector.Dot(normUAcc, Vector<float>.One);
            float normV = Vector.Dot(normVAcc, Vector<float>.One);

            // Scalar tail for the elements that do not fill a whole vector.
            for (; i < u.Length; ++i)
            {
                dot += u[i] * v[i];
                normU += u[i] * u[i];
                normV += v[i] * v[i];
            }

            float denominator = (float)Math.Sqrt(normU) * (float)Math.Sqrt(normV);
            if (denominator == 0f)
            {
                return 1f;
            }
            return 1f - dot / denominator;
        }

        /// <summary>
        /// Calculates cosine distance using SIMD instructions, with the assumption that
        /// <paramref name="u"/> and <paramref name="v"/> are unit vectors (their norms are not computed or checked).
        /// Unlike <see cref="SIMD"/>, this method does not check <see cref="Vector.IsHardwareAccelerated"/>.
        /// </summary>
        /// <param name="u">Left unit vector.</param>
        /// <param name="v">Right unit vector.</param>
        /// <returns><c>1 - dot(u, v)</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="u"/> or <paramref name="v"/> is null.</exception>
        /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
        public static float SIMDForUnits(float[] u, float[] v)
        {
            ValidateDimensions(u, v);
            return 1f - DotProduct(u, v);
        }

        /// <summary>
        /// Throws unless both vectors are non-null and have the same length.
        /// </summary>
        private static void ValidateDimensions(float[] u, float[] v)
        {
            if (u == null)
            {
                throw new ArgumentNullException(nameof(u));
            }
            if (v == null)
            {
                throw new ArgumentNullException(nameof(v));
            }
            if (u.Length != v.Length)
            {
                throw new ArgumentException("Vectors have non-matching dimensions");
            }
        }

        /// <summary>
        /// Dot product of two vectors of equal length (the caller validates this), computed with
        /// <see cref="Vector{T}"/> lanes followed by a scalar tail.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static float DotProduct(float[] lhs, float[] rhs)
        {
            int width = Vector<float>.Count;
            int block = 4 * width;
            int length = lhs.Length;
            int offset = 0;

            // Main loop: four independent accumulators so consecutive multiply-adds do not
            // depend on each other.
            var acc0 = Vector<float>.Zero;
            var acc1 = Vector<float>.Zero;
            var acc2 = Vector<float>.Zero;
            var acc3 = Vector<float>.Zero;
            for (int last = length - block; offset <= last; offset += block)
            {
                acc0 += new Vector<float>(lhs, offset) * new Vector<float>(rhs, offset);
                acc1 += new Vector<float>(lhs, offset + width) * new Vector<float>(rhs, offset + width);
                acc2 += new Vector<float>(lhs, offset + 2 * width) * new Vector<float>(rhs, offset + 2 * width);
                acc3 += new Vector<float>(lhs, offset + 3 * width) * new Vector<float>(rhs, offset + 3 * width);
            }

            // Remaining whole vectors (at most three).
            for (int last = length - width; offset <= last; offset += width)
            {
                acc0 += new Vector<float>(lhs, offset) * new Vector<float>(rhs, offset);
            }

            // Single horizontal reduction: dot with a vector of ones sums the lanes.
            float result = Vector.Dot((acc0 + acc1) + (acc2 + acc3), Vector<float>.One);

            // Scalar tail.
            for (; offset < length; offset++)
            {
                result += lhs[offset] * rhs[offset];
            }
            return result;
        }
    }
}