/// <summary>
/// Demo of filtered search: builds an HNSW graph from randomly generated persons,
/// restricts the search to the ids of women aged 21..49 and counts how many of the
/// returned records satisfy (hits) or violate (misses) that restriction.
/// </summary>
static void FilterTest()
{
    const int dimensionality = 128;
    const int count = 5000;
    const int testCount = 1000;
    const int K = 200;

    var samples = Person.GenerateRandom(dimensionality, count);

    // NOTE(review): generic arguments were stripped from the extracted patch text;
    // float[] is presumed from the vector helpers used here - confirm against the project.
    var world = new SmallWorld<float[]>(
        NSWOptions<float[]>.Create(
            6, 15, 200, 200,
            CosineDistance.ForUnits,
            true, true,
            selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));

    // Insert all vectors at once, then remember which record each graph id belongs to.
    var ids = world.AddItems(samples.Select(sample => sample.Item1).ToArray());
    for (int index = 0; index < samples.Count; index++)
    {
        _database.Add(ids[index], samples[index].Item2);
    }

    Console.WriteLine("Start test");
    var vectors = RandomVectors(dimensionality, testCount);

    // Ids of the records that are allowed to appear in search results.
    var activeNodes = _database
        .Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale)
        .Select(pair => pair.Key)
        .ToHashSet();

    int hits = 0;
    int miss = 0;
    foreach (var query in vectors)
    {
        foreach (var (id, _, _) in world.Search(query, K, activeNodes))
        {
            var record = _database[id];
            bool matchesFilter = record.Gender == Gender.Feemale && record.Age > 20 && record.Age < 50;
            if (matchesFilter)
            {
                hits++;
            }
            else
            {
                miss++;
            }
        }
    }

    Console.WriteLine($"SUCCESS: {hits}");
    Console.WriteLine($"ERROR: {miss}");
}
/// <summary>
/// Algorithm 2 (SEARCH-LAYER) restricted to a set of active nodes.
/// The graph is traversed through every reachable node, but only nodes contained in
/// <paramref name="activeNodes"/> are placed into the result set.
/// </summary>
/// <param name="entryPointId">Enter point ep.</param>
/// <param name="targetCosts">Distance from the query element q to the node with the given id.</param>
/// <param name="W">Output: up to <paramref name="ef"/> closest active neighbors to q (node id -> distance).</param>
/// <param name="ef">Maximum number of elements kept in <paramref name="W"/>.</param>
/// <param name="activeNodes">
/// Ids of the nodes allowed in the result. When null no filtering is applied and the
/// call is forwarded to the unfiltered overload (callers such as Search pass null by default).
/// </param>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, HashSet<int> activeNodes)
{
    // A null filter means "every node is active". Without this guard the
    // activeNodes.Contains(...) calls below throw NullReferenceException.
    if (activeNodes == null)
    {
        KNearestAtLayer(entryPointId, targetCosts, W, ef);
        return;
    }

    /*
     * v ← ep // set of visited elements
     * C ← ep // set of candidates
     * W ← ep // dynamic list of found nearest neighbors (active nodes only)
     * while │C│ > 0
     *   c ← extract nearest element from C to q
     *   f ← get furthest element from W to q
     *   if distance(c, q) > distance(f, q)
     *     break // all elements in W are evaluated
     *   for each e ∈ neighbourhood(c) at layer lc // update C and W
     *     if e ∉ v
     *       v ← v ⋃ e
     *       if e is active and (distance(e, q) < distance(f, q) or │W│ < ef)
     *         W ← W ⋃ e
     *         if │W│ > ef
     *           remove furthest element from W to q
     *       if │W│ < ef
     *         C ← C ⋃ e
     * return W
     */
    var v = new VisitedBitSet(_vectors.Count, _options.M);
    // v ← ep // set of visited elements
    v.Add(entryPointId);

    // C ← ep // set of candidates
    var C = new Dictionary<int, float>();
    var entryDistance = targetCosts(entryPointId);
    C.Add(entryPointId, entryDistance);

    // W ← ep // the entry point becomes a result only when it passes the filter
    if (activeNodes.Contains(entryPointId))
    {
        W.Add(entryPointId, entryDistance);
    }

    // Removes and returns the candidate closest to q.
    (int, float) PopNearestCandidate()
    {
        var nearest = C.OrderBy(e => e.Value).First();
        C.Remove(nearest.Key);
        return (nearest.Key, nearest.Value);
    }

    // Distance of the current worst result; W must not be empty.
    float FarthestResultDistance() => W.OrderByDescending(e => e.Value).First().Value;

    // Drops the current worst result; W must not be empty.
    void RemoveFarthestResult() => W.Remove(W.OrderByDescending(e => e.Value).First().Key);

    // run bfs
    while (C.Count > 0)
    {
        // get next candidate to check and expand
        var toExpand = PopNearestCandidate();

        // W may still be empty here when no active node has been reached yet
        if (W.Count > 0 && toExpand.Item2 > FarthestResultDistance())
        {
            // the closest candidate is farther than the farthest result
            break;
        }

        // expand candidate
        var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
        for (int i = 0; i < neighboursIds.Length; ++i)
        {
            int neighbourId = neighboursIds[i];
            if (v.Contains(neighbourId))
            {
                continue;
            }

            var neighbourDistance = targetCosts(neighbourId);
            if (activeNodes.Contains(neighbourId))
            {
                if (W.Count < ef || (W.Count > 0 && neighbourDistance < FarthestResultDistance()))
                {
                    W.Add(neighbourId, neighbourDistance);
                    if (W.Count > ef)
                    {
                        RemoveFarthestResult();
                    }
                }
            }

            // inactive nodes are still queued for expansion while the result set is not full,
            // so the search can pass through them to reach active ones
            if (W.Count < ef)
            {
                C.Add(neighbourId, neighbourDistance);
            }
            v.Add(neighbourId);
        }
    }
    C.Clear();
    v.Clear();
}
/// <summary>
/// Searches the graph for the vectors closest to a given vector.
/// </summary>
/// <param name="vector">Given vector.</param>
/// <param name="k">Count of elements to search for.</param>
/// <param name="activeNodes">Optional set of node ids, forwarded unchanged to KNearest; null by default.</param>
/// <returns>Lazily produced tuples of (node id, stored vector, distance value reported by KNearest).</returns>
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes = null)
{
    foreach (var (id, distance) in KNearest(vector, k, activeNodes))
    {
        // attach the stored vector to every (id, distance) pair found in the graph
        yield return (id, _vectors[id], distance);
    }
}
(_layers[lc].CountLinks == 0) + if (_layers[lc].HasLinks == false) { _layers[lc].Append(q); ep = q; @@ -150,7 +141,7 @@ namespace ZeroLevel.HNSW else { // W ← SEARCH - LAYER(q, ep, efConstruction, lc) - _layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction); + _layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction); // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 var neighbors = SelectBestForConnecting(lc, distance, W); // add bidirectionall connectionts from neighbors to q at layer lc @@ -158,7 +149,7 @@ namespace ZeroLevel.HNSW foreach (var e in neighbors) { // eConn ← neighbourhood(e) at layer lc - _layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value, lc == 0); + _layers[lc].AddBidirectionallConnections(q, e.Key, e.Value, lc == 0); // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer if (e.Value < epDist) { @@ -195,7 +186,7 @@ namespace ZeroLevel.HNSW /// /// The level of the layer. /// The maximum number of connections. - internal int GetM(int layer) + private int GetM(int layer) { return layer == 0 ? 2 * _options.M : _options.M; } @@ -210,7 +201,7 @@ namespace ZeroLevel.HNSW /// /// Algorithm 5 /// - internal IEnumerable<(int, float)> KNearest(TItem q, int k) + private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet activeNodes = null) { _lockGraph.EnterReadLock(); try @@ -231,13 +222,13 @@ namespace ZeroLevel.HNSW for (int layer = L; layer > 0; --layer) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[layer].RunKnnAtLayer(ep, distance, W, 1); + _layers[layer].KNearestAtLayer(ep, distance, W, 1); // ep ← get nearest element from W to q ep = W.OrderBy(p => p.Value).First().Key; W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].RunKnnAtLayer(ep, distance, W, k); + _layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes); // return K nearest elements from W to q return W.Select(p => (p.Key, p.Value)); }