diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index ee4610f..48ad3d9 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -2,13 +2,35 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Threading; using ZeroLevel.HNSW; namespace HNSWDemo { class Program { + public class VectorsDirectCompare + { + private readonly IList _vectors; + private readonly Func _distance; + + public VectorsDirectCompare(List vectors, Func distance) + { + _vectors = vectors; + _distance = distance; + } + + public IEnumerable<(int, float)> KNearest(float[] v, int k) + { + var weights = new Dictionary(); + for (int i = 0; i < _vectors.Count; i++) + { + var d = _distance(v, _vectors[i]); + weights[i] = d; + } + return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); + } + } + public enum Gender { Unknown, Male, Feemale @@ -77,28 +99,65 @@ namespace HNSWDemo static void Main(string[] args) { var dimensionality = 128; - var testCount = 5000; - var count = 10000; - var batchSize = 1000; + var testCount = 1000; + var count = 5000; var samples = Person.GenerateRandom(dimensionality, count); var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 10, 200, 200, CosineDistance.ForUnits, false, false, selectionHeuristic: NeighbourSelectionHeuristic.SelectHeuristic)); - for (int i = 0; i < (count / batchSize); i++) + var test = new VectorsDirectCompare(samples.Select(s => s.Item1).ToList(), CosineDistance.ForUnits); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + + var batch = samples.ToArray(); + + var ids = world.AddItems(batch.Select(i => i.Item1).ToArray()); + + Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); + for (int bi = 0; bi < batch.Length; bi++) + { + _database.Add(ids[bi], batch[bi].Item2); + } + + Console.WriteLine("Start test"); + int K = 200; + var vectors = RandomVectors(dimensionality, testCount); + var totalHits = new List(); + var timewatchesHNSW = new List(); + var timewatchesNP = new List(); + foreach (var v in vectors) { - var batch = samples.Skip(i * batchSize).Take(batchSize).ToArray(); sw.Restart(); - var ids = world.AddItems(batch.Select(i => i.Item1).ToArray()); + var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); sw.Stop(); - Console.WriteLine($"Batch [{i}]. Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); - for (int bi = 0; bi < batch.Length; bi++) + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = world.Search(v, K); + sw.Stop(); + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + var hits = 0; + foreach (var r in result) { - _database.Add(ids[bi], batch[bi].Item2); + if (gt.ContainsKey(r.Item1)) + { + hits++; + } } + totalHits.Add(hits); } + Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + - var vectors = RandomVectors(dimensionality, testCount); //HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; })); diff --git a/ZeroLevel.HNSW/Layer.cs b/ZeroLevel.HNSW/Layer.cs index 577c6b0..4ff91ea 100644 --- a/ZeroLevel.HNSW/Layer.cs +++ b/ZeroLevel.HNSW/Layer.cs @@ -16,7 +16,7 @@ namespace ZeroLevel.HNSW /// /// Count nodes at layer /// - public int Count => (_links.Count >> 1); + public int CountLinks => (_links.Count); public Layer(NSWOptions options, VectorSet vectors) { @@ -24,13 +24,13 @@ namespace ZeroLevel.HNSW _vectors = vectors; } - public void AddBidirectionallConnectionts(int q, int p, float qpDistance) + public void AddBidirectionallConnectionts(int q, int p, float qpDistance, bool isMapLayer) { // поиск в ширину ближайших узлов к найденному var nearest = _links.FindLinksForId(p).ToArray(); // если у найденного узла максимальное количество связей // if │eConn│ > Mmax // shrink connections of e - if (nearest.Length >= _options.M) + if (nearest.Length >= (isMapLayer ? _options.M * 2 : _options.M)) { // ищем связь с самой большой дистанцией float distance = nearest[0].Item3; @@ -55,6 +55,12 @@ namespace ZeroLevel.HNSW } } + public void Append(int q) + { + _links.Add(q, q, 0); + } + + #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf /// /// Algorithm 2 diff --git a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs index 026200b..1ea22db 100644 --- a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs +++ b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs @@ -43,8 +43,10 @@ namespace ZeroLevel.HNSW { _set.Remove(k1old); _set.Remove(k2old); - _set.Add(k1new, distance); - _set.Add(k2new, distance); + if (!_set.ContainsKey(k1new)) + _set.Add(k1new, distance); + if (!_set.ContainsKey(k2new)) + _set.Add(k2new, distance); } finally { diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index 477957d..437ead9 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -7,7 +7,7 @@ namespace ZeroLevel.HNSW { public class ProbabilityLayerNumberGenerator { - private const float DIVIDER = 4.362f; + private const float DIVIDER = 3.361f; private readonly float[] _probabilities; public ProbabilityLayerNumberGenerator(int maxLayers, int M) @@ -38,11 +38,8 @@ namespace ZeroLevel.HNSW private readonly NSWOptions _options; private readonly VectorSet _vectors; private readonly Layer[] _layers; - - private Layer EnterPointsLayer => _layers[_layers.Length - 1]; - private Layer LastLayer => _layers[0]; - private int EntryPoint = -1; - private int MaxLayer = -1; + private int EntryPoint = 0; + private int MaxLayer = 0; private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); @@ -58,9 +55,12 @@ namespace ZeroLevel.HNSW } } - public IEnumerable<(int, TItem[])> Search(TItem vector, int k, HashSet activeNodes = null) + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet activeNodes = null) { - return Enumerable.Empty<(int, TItem[])>(); + foreach (var pair in KNearest(vector, k)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } } public int[] AddItems(IEnumerable vectors) @@ -109,57 +109,67 @@ namespace ZeroLevel.HNSW public void INSERT(int q) { var distance = new Func(candidate => _options.Distance(_vectors[q], _vectors[candidate])); - // W ← ∅ // list for the currently found nearest elements IDictionary W = new Dictionary(); // ep ← get enter point for hnsw - var ep = EntryPoint == -1 ? 0 : EntryPoint; - var epDist = 0.0f; + var ep = EntryPoint; + var epDist = distance(ep); // L ← level of ep // top layer for hnsw var L = MaxLayer; // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level - int l = _layerLevelGenerator.GetRandomLayer(); - if (L == -1) - { - L = l; - MaxLayer = l; - } + int l = _layerLevelGenerator.GetRandomLayer(); // for lc ← L … l+1 // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа for (int lc = L; lc > l; --lc) { - // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[lc].RunKnnAtLayer(ep, distance, W, 1); - // ep ← get the nearest element from W to q - var nearest = W.OrderBy(p => p.Value).First(); - ep = nearest.Key; - epDist = nearest.Value; - W.Clear(); + if (_layers[lc].CountLinks == 0) + { + _layers[lc].Append(q); + ep = q; + } + else + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[lc].RunKnnAtLayer(ep, distance, W, 1); + // ep ← get the nearest element from W to q + var nearest = W.OrderBy(p => p.Value).First(); + ep = nearest.Key; + epDist = nearest.Value; + W.Clear(); + } } //for lc ← min(L, l) … 0 // connecting new node to the small world for (int lc = Math.Min(L, l); lc >= 0; --lc) { - // W ← SEARCH - LAYER(q, ep, efConstruction, lc) - _layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction); - // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 - var neighbors = SelectBestForConnecting(lc, distance, W);; - // add bidirectionall connectionts from neighbors to q at layer lc - // for each e ∈ neighbors // shrink connections if needed - foreach (var e in neighbors) + if (_layers[lc].CountLinks == 0) { - // eConn ← neighbourhood(e) at layer lc - _layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value); - // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer - if (e.Value < epDist) + _layers[lc].Append(q); + ep = q; + } + else + { + // W ← SEARCH - LAYER(q, ep, efConstruction, lc) + _layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction); + // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 + var neighbors = SelectBestForConnecting(lc, distance, W); + // add bidirectionall connectionts from neighbors to q at layer lc + // for each e ∈ neighbors // shrink connections if needed + foreach (var e in neighbors) { - ep = e.Key; - epDist = e.Value; + // eConn ← neighbourhood(e) at layer lc + _layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value, lc == 0); + // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer + if (e.Value < epDist) + { + ep = e.Key; + epDist = e.Value; + } } + // ep ← W + ep = W.OrderBy(p => p.Value).First().Key; + W.Clear(); } - // ep ← W - ep = W.OrderBy(p => p.Value).First().Key; - W.Clear(); } // if l > L if (l > L) @@ -192,7 +202,7 @@ namespace ZeroLevel.HNSW private IDictionary SelectBestForConnecting(int layer, Func distance, IDictionary candidates) { - if(_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple) + if (_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple) return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer)); return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer)); }