using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.HNSW;

namespace HNSWDemo.Utils
{
    /// <summary>
    /// Brute-force comparison over an in-memory vector set: exact k-nearest-neighbour search and a
    /// threshold-based clustering, both computed by calling the supplied distance function directly.
    /// </summary>
    public class VectorsDirectCompare
    {
        // A vertex pair (i, j) is packed into a single long key: i in the high 32 bits, j in the low 32 bits.
        private const int HALF_LONG_BITS = 32;
        private readonly IList<float[]> _vectors;
        private readonly Func<float[], float[], float> _distance;

        /// <summary>
        /// Creates a comparer over <paramref name="vectors"/>; vector ids are their list indices.
        /// </summary>
        /// <param name="vectors">The vectors to search/cluster. The list is stored by reference, not copied.</param>
        /// <param name="distance">Distance function; smaller values mean "closer".</param>
        public VectorsDirectCompare(List<float[]> vectors, Func<float[], float[], float> distance)
        {
            _vectors = vectors;
            _distance = distance;
        }

        /// <summary>
        /// Exact k-nearest-neighbour search by computing the distance from <paramref name="v"/> to every stored vector.
        /// </summary>
        /// <param name="v">Query vector.</param>
        /// <param name="k">Maximum number of results.</param>
        /// <returns>
        /// Up to <paramref name="k"/> (index, distance) pairs in ascending distance order. Distances are computed
        /// eagerly; sorting and truncation are deferred until the result is enumerated.
        /// </returns>
        public IEnumerable<(int, float)> KNearest(float[] v, int k)
        {
            var weights = new Dictionary<int, float>();
            for (int i = 0; i < _vectors.Count; i++)
            {
                weights[i] = _distance(v, _vectors[i]);
            }
            return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
        }

        /// <summary>
        /// Splits the vectors into clusters. All pairwise distances are computed and a cut-off radius R is derived
        /// from their histogram. Every pair closer than R is treated as linked, and each connected component of
        /// linked vectors becomes one cluster.
        /// </summary>
        /// <returns>
        /// Disjoint sets of vector indices, in order of first appearance among the linked pairs. Vectors with no
        /// link shorter than R are not included in any cluster. Returns an empty list for fewer than two vectors.
        /// </returns>
        public List<HashSet<int>> DetectClusters()
        {
            // With fewer than two vectors there are no pairs to link, so there can be no clusters.
            if (_vectors.Count < 2)
            {
                return new List<HashSet<int>>();
            }

            var links = new SortedList<long, float>();
            for (int i = 0; i < _vectors.Count; i++)
            {
                for (int j = i + 1; j < _vectors.Count; j++)
                {
                    long k = (((long)i) << HALF_LONG_BITS) + j;
                    links.Add(k, _distance(_vectors[i], _vectors[j]));
                }
            }

            // 1. Find R - bound between intra-cluster distances and out-of-cluster distances
            var histogram = new Histogram(HistogramMode.SQRT, links.Values);
            int threshold = histogram.CuttOff();
            // NOTE(review): assumes CuttOff() returns an index >= 1; a 0 result would make Bounds[threshold - 1] throw.
            var min = histogram.Bounds[threshold - 1];
            var max = histogram.Bounds[threshold];
            var R = (max + min) / 2;

            // 2. Keep only links shorter than R, decoding each packed key back into its vertex pair.
            var closeLinks = new List<(int, int)>();
            foreach (var pair in links)
            {
                if (pair.Value < R)
                {
                    var id1 = (int)(pair.Key >> HALF_LONG_BITS);
                    var id2 = (int)(pair.Key - (((long)id1) << HALF_LONG_BITS));
                    closeLinks.Add((id1, id2));
                }
            }

            // 3. Extract clusters as connected components of the close links.
            return GroupIntoClusters(closeLinks);
        }

        // Groups the endpoints of the given edges into connected components (union-find), ordered by first appearance.
        private static List<HashSet<int>> GroupIntoClusters(IEnumerable<(int, int)> edges)
        {
            var parent = new Dictionary<int, int>();
            var seenOrder = new List<int>();

            int Find(int x)
            {
                var root = x;
                while (parent[root] != root)
                {
                    root = parent[root];
                }
                // Path compression keeps later lookups short.
                while (parent[x] != root)
                {
                    var next = parent[x];
                    parent[x] = root;
                    x = next;
                }
                return root;
            }

            void Register(int x)
            {
                if (!parent.ContainsKey(x))
                {
                    parent[x] = x;
                    seenOrder.Add(x);
                }
            }

            foreach (var (a, b) in edges)
            {
                Register(a);
                Register(b);
                var rootA = Find(a);
                var rootB = Find(b);
                if (rootA != rootB)
                {
                    // Merging here is what the previous "add to the first matching cluster" logic missed:
                    // an edge bridging two existing clusters must join them into one.
                    parent[rootB] = rootA;
                }
            }

            var clusters = new List<HashSet<int>>();
            var clusterByRoot = new Dictionary<int, HashSet<int>>();
            foreach (var id in seenOrder)
            {
                var root = Find(id);
                if (!clusterByRoot.TryGetValue(root, out var cluster))
                {
                    cluster = new HashSet<int>();
                    clusterByRoot[root] = cluster;
                    clusters.Add(cluster);
                }
                cluster.Add(id);
            }
            return clusters;
        }
    }
}