HNSW. Append map

Map stores the correspondence between the object feature and the vector identifier.
pull/1/head
unknown 3 years ago
parent c128396ac1
commit d1569628b6

@ -95,8 +95,6 @@ namespace HNSWDemo
return vectors; return vectors;
} }
private static Dictionary<int, Person> _database = new Dictionary<int, Person>();
static void Main(string[] args) static void Main(string[] args)
{ {
FilterTest(); FilterTest();
@ -367,28 +365,31 @@ namespace HNSWDemo
var dimensionality = 128; var dimensionality = 128;
var samples = Person.GenerateRandom(dimensionality, count); var samples = Person.GenerateRandom(dimensionality, count);
var testDict = samples.ToDictionary(s => s.Item2.Number, s => s.Item2);
var map = new HNSWMap<long>();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var ids = world.AddItems(samples.Select(i => i.Item1).ToArray()); var ids = world.AddItems(samples.Select(i => i.Item1).ToArray());
for (int bi = 0; bi < samples.Count; bi++) for (int bi = 0; bi < samples.Count; bi++)
{ {
_database.Add(ids[bi], samples[bi].Item2); map.Append(samples[bi].Item2.Number, ids[bi]);
} }
Console.WriteLine("Start test"); Console.WriteLine("Start test");
int K = 200; int K = 200;
var vectors = RandomVectors(dimensionality, testCount); var vectors = RandomVectors(dimensionality, testCount);
var context = new SearchContext().SetActiveNodes(_database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key)); var context = new SearchContext().SetActiveNodes(map.ConvertFeaturesToIds(samples.Where(p => p.Item2.Age > 20 && p.Item2.Age < 50 && p.Item2.Gender == Gender.Feemale).Select(p => p.Item2.Number)));
var hits = 0; var hits = 0;
var miss = 0; var miss = 0;
foreach (var v in vectors) foreach (var v in vectors)
{ {
var result = world.Search(v, K, context); var numbers = map.ConvertIdsToFeatures( world.Search(v, K, context).Select(r=>r.Item1));
foreach (var r in result) foreach (var r in numbers)
{ {
var record = _database[r.Item1]; var record = testDict[r];
if (record.Gender == Gender.Feemale && record.Age > 20 && record.Age < 50) if (record.Gender == Gender.Feemale && record.Age > 20 && record.Age < 50)
{ {
hits++; hits++;

@ -0,0 +1,44 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
namespace ZeroLevel.HNSW
{
// object -> vector -> vectorId
// HNSW vectorId + vector
// Map object feature - vectorId
public class HNSWMap<TFeature>
{
private readonly ConcurrentDictionary<TFeature, int> _map = new ConcurrentDictionary<TFeature, int>();
private readonly ConcurrentDictionary<int, TFeature> _reverse_map = new ConcurrentDictionary<int, TFeature>();
public void Append(TFeature feature, int vectorId)
{
_map[feature] = vectorId;
_reverse_map[vectorId] = feature;
}
public IEnumerable<int> ConvertFeaturesToIds(IEnumerable<TFeature> features)
{
int id;
foreach (var feature in features)
{
if (_map.TryGetValue(feature, out id))
{
yield return id;
}
}
}
public IEnumerable<TFeature> ConvertIdsToFeatures(IEnumerable<int> ids)
{
TFeature feature;
foreach (var id in ids)
{
if (_reverse_map.TryGetValue(id, out feature))
{
yield return feature;
}
}
}
}
}

@ -16,7 +16,7 @@ namespace ZeroLevel.HNSW
} }
private HashSet<int> _activeNodes; private HashSet<int> _activeNodes;
private HashSet<int> _inactiveNodes; private HashSet<int> _entryNodes;
private Mode _mode; private Mode _mode;
public SearchContext() public SearchContext()
@ -30,8 +30,8 @@ namespace ZeroLevel.HNSW
switch (_mode) switch (_mode)
{ {
case Mode.ActiveCheck: return _activeNodes.Contains(nodeId); case Mode.ActiveCheck: return _activeNodes.Contains(nodeId);
case Mode.InactiveCheck: return _inactiveNodes.Contains(nodeId) == false; case Mode.InactiveCheck: return _entryNodes.Contains(nodeId) == false;
case Mode.ActiveInactiveCheck: return _inactiveNodes.Contains(nodeId) == false && _activeNodes.Contains(nodeId); case Mode.ActiveInactiveCheck: return _entryNodes.Contains(nodeId) == false && _activeNodes.Contains(nodeId);
} }
return nodeId >= 0; return nodeId >= 0;
} }
@ -57,15 +57,15 @@ namespace ZeroLevel.HNSW
return this; return this;
} }
public SearchContext SetInactiveNodes(IEnumerable<int> inactiveNodes) public SearchContext SetEntryPointsNodes(IEnumerable<int> entryNodes)
{ {
if (inactiveNodes != null && inactiveNodes.Any()) if (entryNodes != null && entryNodes.Any())
{ {
if (_mode == Mode.InactiveCheck || _mode == Mode.ActiveInactiveCheck) if (_mode == Mode.InactiveCheck || _mode == Mode.ActiveInactiveCheck)
{ {
throw new InvalidOperationException("Inctive nodes are already defined"); throw new InvalidOperationException("Inctive nodes are already defined");
} }
_inactiveNodes = new HashSet<int>(inactiveNodes); _entryNodes = new HashSet<int>(entryNodes);
if (_mode == Mode.None) if (_mode == Mode.None)
{ {
_mode = Mode.InactiveCheck; _mode = Mode.InactiveCheck;

Loading…
Cancel
Save

Powered by TurnKey Linux.