|
|
|
|
using System;
|
|
|
|
|
using System.Collections.Generic;
|
|
|
|
|
using System.Linq;
|
|
|
|
|
using ZeroLevel.HNSW;
|
|
|
|
|
|
|
|
|
|
namespace HNSWDemo.Utils
{
    /// <summary>
    /// Exhaustive (brute-force) comparisons over an in-memory set of vectors:
    /// exact k-nearest-neighbour lookup and a distance-threshold clustering.
    /// A vector's id is its index in the list passed to the constructor.
    /// </summary>
    public class VectorsDirectCompare
    {
        private readonly IList<float[]> _vectors;
        private readonly Func<float[], float[], float> _distance;

        /// <summary>
        /// Creates a comparer over <paramref name="vectors"/> using <paramref name="distance"/>.
        /// </summary>
        /// <param name="vectors">Vectors to compare; stored by reference, not copied.</param>
        /// <param name="distance">Distance function; smaller values are treated as closer.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="vectors"/> or <paramref name="distance"/> is null.
        /// </exception>
        public VectorsDirectCompare(List<float[]> vectors, Func<float[], float[], float> distance)
        {
            _vectors = vectors ?? throw new ArgumentNullException(nameof(vectors));
            _distance = distance ?? throw new ArgumentNullException(nameof(distance));
        }

        /// <summary>
        /// Finds the <paramref name="k"/> vectors closest to <paramref name="v"/> by scanning all of them.
        /// </summary>
        /// <param name="v">Query vector, passed as the first argument to the distance function.</param>
        /// <param name="k">Maximum number of results; a value &lt;= 0 yields no results.</param>
        /// <returns>
        /// (index, distance) pairs in ascending distance order; equal distances keep ascending index order.
        /// </returns>
        public IEnumerable<(int, float)> KNearest(float[] v, int k)
        {
            // Distances are computed eagerly here; only the ordering/Take is deferred.
            var scored = new (int, float)[_vectors.Count];
            for (int i = 0; i < scored.Length; i++)
            {
                scored[i] = (i, _distance(v, _vectors[i]));
            }
            // OrderBy is a stable sort, so ties stay in index order.
            return scored.OrderBy(p => p.Item2).Take(k);
        }

        /// <summary>
        /// Splits the vectors into clusters. A link joins every pair whose distance is below a
        /// threshold R derived from the histogram of all pairwise distances; a cluster is a
        /// connected component of that link graph.
        /// </summary>
        /// <returns>
        /// Disjoint clusters of vector indices, ordered by their smallest member. Vectors with no link
        /// below R appear in no cluster, so every returned cluster has at least two members.
        /// Empty when there are fewer than two vectors.
        /// </returns>
        public List<HashSet<int>> DetectClusters()
        {
            int count = _vectors.Count;
            var clusters = new List<HashSet<int>>();
            if (count < 2)
            {
                // No pairs means no links, hence no clusters.
                return clusters;
            }

            // 0. Pairwise distances for every i < j, stored in (i, j) lexicographic order.
            var distances = new List<float>();
            for (int i = 0; i < count; i++)
            {
                for (int j = i + 1; j < count; j++)
                {
                    distances.Add(_distance(_vectors[i], _vectors[j]));
                }
            }

            // 1. Find R, the bound between intra-cluster and out-of-cluster distances:
            // the midpoint of the histogram bin at the cut-off index.
            var histogram = new Histogram(HistogramMode.SQRT, distances);
            int threshold = histogram.CuttOff();
            // NOTE(review): assumes CuttOff() returns an index >= 1 into Bounds; a return of 0
            // would make Bounds[threshold - 1] go out of range — confirm against Histogram.
            var min = histogram.Bounds[threshold - 1];
            var max = histogram.Bounds[threshold];
            var radius = (max + min) / 2;

            // 2. Link every pair closer than R, merging their components (union-find).
            var parent = new int[count];
            var linked = new bool[count];
            for (int i = 0; i < count; i++)
            {
                parent[i] = i;
            }
            int pairIndex = 0;
            for (int i = 0; i < count; i++)
            {
                for (int j = i + 1; j < count; j++)
                {
                    // Same (i, j) order as the distance list above.
                    if (distances[pairIndex++] < radius)
                    {
                        linked[i] = true;
                        linked[j] = true;
                        Union(parent, i, j);
                    }
                }
            }

            // 3. Extract clusters: one set per component root, created in order of smallest member.
            var clusterByRoot = new Dictionary<int, HashSet<int>>();
            for (int i = 0; i < count; i++)
            {
                if (!linked[i])
                {
                    continue;
                }
                int root = FindRoot(parent, i);
                if (!clusterByRoot.TryGetValue(root, out var cluster))
                {
                    cluster = new HashSet<int>();
                    clusterByRoot.Add(root, cluster);
                    clusters.Add(cluster);
                }
                cluster.Add(i);
            }
            return clusters;
        }

        /// <summary>Returns the representative of <paramref name="x"/>'s set, halving the path as it walks.</summary>
        private static int FindRoot(int[] parent, int x)
        {
            while (parent[x] != x)
            {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }

        /// <summary>Merges the sets containing <paramref name="a"/> and <paramref name="b"/>; the smaller root becomes the representative.</summary>
        private static void Union(int[] parent, int a, int b)
        {
            int rootA = FindRoot(parent, a);
            int rootB = FindRoot(parent, b);
            if (rootA == rootB)
            {
                return;
            }
            if (rootA < rootB)
            {
                parent[rootB] = rootA;
            }
            else
            {
                parent[rootA] = rootB;
            }
        }
    }
}
|