using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.HNSW;

namespace HNSWDemo.Utils
{
    /// <summary>
    /// Exhaustive (pairwise) comparison of a fixed set of vectors using a
    /// caller-supplied distance function. Every query evaluates the distance
    /// function directly against the stored vectors.
    /// </summary>
    public class VectorsDirectCompare
    {
        // A pair of vector indices (i, j) is packed into one long:
        // i in the upper 32 bits, j in the lower 32 bits.
        private const int HalfLongBits = 32;
        private const long LowHalfMask = 0xFFFFFFFFL;

        private readonly IList<float[]> _vectors;
        private readonly Func<float[], float[], float> _distance;

        /// <summary>
        /// Creates a comparer over <paramref name="vectors"/>.
        /// </summary>
        /// <param name="vectors">Vectors to compare; the list is stored by reference, not copied.</param>
        /// <param name="distance">Function returning the distance between two vectors.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="vectors"/> or <paramref name="distance"/> is null.
        /// </exception>
        public VectorsDirectCompare(List<float[]> vectors, Func<float[], float[], float> distance)
        {
            _vectors = vectors ?? throw new ArgumentNullException(nameof(vectors));
            _distance = distance ?? throw new ArgumentNullException(nameof(distance));
        }

        /// <summary>
        /// Returns up to <paramref name="k"/> stored vectors with the smallest distance to
        /// <paramref name="v"/>, ordered by ascending distance.
        /// </summary>
        /// <param name="v">Query vector, passed as the first argument of the distance function.</param>
        /// <param name="k">Maximum number of results.</param>
        /// <returns>Pairs of (index of the stored vector, distance to <paramref name="v"/>).</returns>
        public IEnumerable<(int, float)> KNearest(float[] v, int k)
        {
            // Distances are computed eagerly, once per stored vector; only the
            // ordering and truncation below are deferred.
            var weights = new (int Index, float Distance)[_vectors.Count];
            for (int i = 0; i < weights.Length; i++)
            {
                weights[i] = (i, _distance(v, _vectors[i]));
            }

            // OrderBy is a stable sort, so equal distances keep ascending index order.
            return weights
                .OrderBy(p => p.Distance)
                .Take(k)
                .Select(p => (p.Index, p.Distance));
        }

        /// <summary>
        /// Groups vector indices into clusters: two vectors are linked when their distance
        /// is below a threshold derived from a histogram of all pairwise distances, and
        /// every set of transitively linked vectors forms one cluster.
        /// </summary>
        /// <returns>
        /// Disjoint sets of vector indices. Vectors that have no link below the threshold
        /// do not appear in any set. Empty when fewer than two vectors are stored.
        /// </returns>
        public List<HashSet<int>> DetectClusters()
        {
            var clusters = new List<HashSet<int>>();
            int count = _vectors.Count;
            if (count < 2)
            {
                // No pairs exist, so there is nothing to measure or to link.
                return clusters;
            }

            // Distances for every unordered pair (i < j), keyed by the packed pair.
            // Keys are produced in ascending order, so each Add appends at the end.
            var links = new SortedList<long, float>();
            for (int i = 0; i < count; i++)
            {
                for (int j = i + 1; j < count; j++)
                {
                    links.Add(PackPair(i, j), _distance(_vectors[i], _vectors[j]));
                }
            }

            // 1. Find R - bound between intra-cluster distances and out-of-cluster distances.
            //    R is the midpoint of the two histogram bounds adjacent to the cut-off index.
            // NOTE(review): assumes CuttOff() returns an index >= 1 that is valid for Bounds;
            // Histogram is declared elsewhere - confirm its contract.
            var histogram = new Histogram(HistogramMode.SQRT, links.Values);
            int threshold = histogram.CuttOff();
            var min = histogram.Bounds[threshold - 1];
            var max = histogram.Bounds[threshold];
            var R = (max + min) / 2;

            // 2-3. Walk the links with distance less than R (in ascending key order)
            //      and fold each one into the cluster list.
            foreach (var pair in links)
            {
                if (pair.Value < R)
                {
                    int id1 = (int)(pair.Key >> HalfLongBits);
                    int id2 = (int)(pair.Key & LowHalfMask);
                    AttachLink(clusters, id1, id2);
                }
            }

            return clusters;
        }

        // Packs two non-negative indices into one key: first in the high half, second in the low half.
        private static long PackPair(int first, int second)
        {
            return (((long)first) << HalfLongBits) | (uint)second;
        }

        // Records that id1 and id2 belong together, keeping the clusters disjoint:
        // starts a new cluster, extends an existing one, or merges two existing clusters.
        private static void AttachLink(List<HashSet<int>> clusters, int id1, int id2)
        {
            int pos1 = -1;
            int pos2 = -1;
            for (int c = 0; c < clusters.Count && (pos1 < 0 || pos2 < 0); c++)
            {
                if (pos1 < 0 && clusters[c].Contains(id1))
                {
                    pos1 = c;
                }
                if (pos2 < 0 && clusters[c].Contains(id2))
                {
                    pos2 = c;
                }
            }

            if (pos1 < 0 && pos2 < 0)
            {
                clusters.Add(new HashSet<int> { id1, id2 });
            }
            else if (pos2 < 0)
            {
                clusters[pos1].Add(id2);
            }
            else if (pos1 < 0)
            {
                clusters[pos2].Add(id1);
            }
            else if (pos1 != pos2)
            {
                // The link bridges two clusters: fold the later one into the earlier one
                // so that no index ends up in more than one cluster.
                int keep = Math.Min(pos1, pos2);
                int drop = Math.Max(pos1, pos2);
                clusters[keep].UnionWith(clusters[drop]);
                clusters.RemoveAt(drop);
            }
        }
    }
}