From d1569628b6198ab0d48ac9a488d261565f530cb7 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 10 Dec 2021 05:48:30 +0300 Subject: [PATCH] HNSW. Append map Map stores the correspondence between the object feature and the vector identifier. --- TestHNSW/HNSWDemo/Program.cs | 19 ++++++------ ZeroLevel.HNSW/HNSWMap.cs | 44 +++++++++++++++++++++++++++ ZeroLevel.HNSW/Model/SearchContext.cs | 12 ++++---- 3 files changed, 60 insertions(+), 15 deletions(-) create mode 100644 ZeroLevel.HNSW/HNSWMap.cs diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index 090f6a6..1f40801 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -95,8 +95,6 @@ namespace HNSWDemo return vectors; } - private static Dictionary _database = new Dictionary(); - static void Main(string[] args) { FilterTest(); @@ -200,7 +198,7 @@ namespace HNSWDemo { world.Serialize(ms); dump = ms.ToArray(); - } + } ReadOnlySmallWorld compactWorld; using (var ms = new MemoryStream(dump)) @@ -211,7 +209,7 @@ namespace HNSWDemo // Compare worlds outputs int K = 200; var hits = 0; - var miss = 0; + var miss = 0; var testCount = 2000; var sw = new Stopwatch(); @@ -367,28 +365,31 @@ namespace HNSWDemo var dimensionality = 128; var samples = Person.GenerateRandom(dimensionality, count); + var testDict = samples.ToDictionary(s => s.Item2.Number, s => s.Item2); + + var map = new HNSWMap(); var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); var ids = world.AddItems(samples.Select(i => i.Item1).ToArray()); for (int bi = 0; bi < samples.Count; bi++) { - _database.Add(ids[bi], samples[bi].Item2); + map.Append(samples[bi].Item2.Number, ids[bi]); } Console.WriteLine("Start test"); int K = 200; var vectors = RandomVectors(dimensionality, testCount); - var context = new SearchContext().SetActiveNodes(_database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key)); + var context = new SearchContext().SetActiveNodes(map.ConvertFeaturesToIds(samples.Where(p => p.Item2.Age > 20 && p.Item2.Age < 50 && p.Item2.Gender == Gender.Feemale).Select(p => p.Item2.Number))); var hits = 0; var miss = 0; foreach (var v in vectors) { - var result = world.Search(v, K, context); - foreach (var r in result) + var numbers = map.ConvertIdsToFeatures( world.Search(v, K, context).Select(r=>r.Item1)); + foreach (var r in numbers) { - var record = _database[r.Item1]; + var record = testDict[r]; if (record.Gender == Gender.Feemale && record.Age > 20 && record.Age < 50) { hits++; diff --git a/ZeroLevel.HNSW/HNSWMap.cs b/ZeroLevel.HNSW/HNSWMap.cs new file mode 100644 index 0000000..129151c --- /dev/null +++ b/ZeroLevel.HNSW/HNSWMap.cs @@ -0,0 +1,44 @@ +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace ZeroLevel.HNSW +{ + // object -> vector -> vectorId + // HNSW vectorId + vector + // Map object feature - vectorId + public class HNSWMap + { + private readonly ConcurrentDictionary _map = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _reverse_map = new ConcurrentDictionary(); + + public void Append(TFeature feature, int vectorId) + { + _map[feature] = vectorId; + _reverse_map[vectorId] = feature; + } + + public IEnumerable ConvertFeaturesToIds(IEnumerable features) + { + int id; + foreach (var feature in features) + { + if (_map.TryGetValue(feature, out id)) + { + yield return id; + } + } + } + + public IEnumerable ConvertIdsToFeatures(IEnumerable ids) + { + TFeature feature; + foreach (var id in ids) + { + if (_reverse_map.TryGetValue(id, out feature)) + { + yield return feature; + } + } + } + } +} diff --git a/ZeroLevel.HNSW/Model/SearchContext.cs b/ZeroLevel.HNSW/Model/SearchContext.cs index 4cfce40..13661b4 100644 --- a/ZeroLevel.HNSW/Model/SearchContext.cs +++ b/ZeroLevel.HNSW/Model/SearchContext.cs @@ -16,7 +16,7 @@ namespace ZeroLevel.HNSW } private HashSet _activeNodes; - private HashSet _inactiveNodes; + private HashSet _entryNodes; private Mode _mode; public SearchContext() @@ -30,8 +30,8 @@ namespace ZeroLevel.HNSW switch (_mode) { case Mode.ActiveCheck: return _activeNodes.Contains(nodeId); - case Mode.InactiveCheck: return _inactiveNodes.Contains(nodeId) == false; - case Mode.ActiveInactiveCheck: return _inactiveNodes.Contains(nodeId) == false && _activeNodes.Contains(nodeId); + case Mode.InactiveCheck: return _entryNodes.Contains(nodeId) == false; + case Mode.ActiveInactiveCheck: return _entryNodes.Contains(nodeId) == false && _activeNodes.Contains(nodeId); } return nodeId >= 0; } @@ -57,15 +57,15 @@ namespace ZeroLevel.HNSW return this; } - public SearchContext SetInactiveNodes(IEnumerable inactiveNodes) + public SearchContext SetEntryPointsNodes(IEnumerable entryNodes) { - if (inactiveNodes != null && inactiveNodes.Any()) + if (entryNodes != null && entryNodes.Any()) { if (_mode == Mode.InactiveCheck || _mode == Mode.ActiveInactiveCheck) { throw new InvalidOperationException("Inctive nodes are already defined"); } - _inactiveNodes = new HashSet(inactiveNodes); + _entryNodes = new HashSet(entryNodes); if (_mode == Mode.None) { _mode = Mode.InactiveCheck;