pull/1/head
unknown 3 years ago
parent 0e2f693b87
commit db076efbb3

@ -2,13 +2,35 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Threading;
using ZeroLevel.HNSW; using ZeroLevel.HNSW;
namespace HNSWDemo namespace HNSWDemo
{ {
class Program class Program
{ {
/// <summary>
/// Exhaustive nearest-neighbour search: scores every stored vector against the
/// query with the supplied distance function and returns the closest ones.
/// </summary>
public class VectorsDirectCompare
{
    private readonly IList<float[]> _vectors;
    private readonly Func<float[], float[], float> _distance;

    /// <summary>
    /// Stores the vector list and the distance function used by <see cref="KNearest"/>.
    /// </summary>
    /// <param name="vectors">Vectors to search; a vector's position in the list is its id.</param>
    /// <param name="distance">Distance function, called as distance(query, stored).</param>
    public VectorsDirectCompare(List<float[]> vectors, Func<float[], float[], float> distance)
    {
        _distance = distance;
        _vectors = vectors;
    }

    /// <summary>
    /// Returns up to <paramref name="k"/> (index, distance) pairs ordered by ascending distance.
    /// </summary>
    /// <param name="v">Query vector.</param>
    /// <param name="k">Maximum number of pairs to return.</param>
    /// <returns>Lazily ordered sequence of (index into the stored list, distance to the query).</returns>
    public IEnumerable<(int, float)> KNearest(float[] v, int k)
    {
        // All distances are computed up front, before the lazy ordering query is returned.
        var scored = new (int Index, float Distance)[_vectors.Count];
        for (int index = 0; index < scored.Length; index++)
        {
            scored[index] = (index, _distance(v, _vectors[index]));
        }

        // OrderBy is a stable sort, so equal distances stay in ascending index order.
        return scored
            .OrderBy(entry => entry.Distance)
            .Take(k)
            .Select(entry => (entry.Index, entry.Distance));
    }
}
public enum Gender public enum Gender
{ {
Unknown, Male, Feemale Unknown, Male, Feemale
@ -77,28 +99,65 @@ namespace HNSWDemo
static void Main(string[] args) static void Main(string[] args)
{ {
var dimensionality = 128; var dimensionality = 128;
var testCount = 5000; var testCount = 1000;
var count = 10000; var count = 5000;
var batchSize = 1000;
var samples = Person.GenerateRandom(dimensionality, count); var samples = Person.GenerateRandom(dimensionality, count);
var sw = new Stopwatch(); var sw = new Stopwatch();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 10, 200, 200, CosineDistance.ForUnits, false, false, selectionHeuristic: NeighbourSelectionHeuristic.SelectHeuristic));
for (int i = 0; i < (count / batchSize); i++) var test = new VectorsDirectCompare(samples.Select(s => s.Item1).ToList(), CosineDistance.ForUnits);
{ var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var batch = samples.Skip(i * batchSize).Take(batchSize).ToArray();
sw.Restart(); var batch = samples.ToArray();
var ids = world.AddItems(batch.Select(i => i.Item1).ToArray()); var ids = world.AddItems(batch.Select(i => i.Item1).ToArray());
sw.Stop();
Console.WriteLine($"Batch [{i}]. Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
for (int bi = 0; bi < batch.Length; bi++) for (int bi = 0; bi < batch.Length; bi++)
{ {
_database.Add(ids[bi], batch[bi].Item2); _database.Add(ids[bi], batch[bi].Item2);
} }
}
Console.WriteLine("Start test");
int K = 200;
var vectors = RandomVectors(dimensionality, testCount); var vectors = RandomVectors(dimensionality, testCount);
var totalHits = new List<int>();
var timewatchesHNSW = new List<float>();
var timewatchesNP = new List<float>();
foreach (var v in vectors)
{
sw.Restart();
var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2);
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = world.Search(v, K);
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
var hits = 0;
foreach (var r in result)
{
if (gt.ContainsKey(r.Item1))
{
hits++;
}
}
totalHits.Add(hits);
}
Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
//HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; })); //HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; }));

@ -16,7 +16,7 @@ namespace ZeroLevel.HNSW
/// <summary> /// <summary>
/// Count nodes at layer /// Count nodes at layer
/// </summary> /// </summary>
public int Count => (_links.Count >> 1); public int CountLinks => (_links.Count);
public Layer(NSWOptions<TItem> options, VectorSet<TItem> vectors) public Layer(NSWOptions<TItem> options, VectorSet<TItem> vectors)
{ {
@ -24,13 +24,13 @@ namespace ZeroLevel.HNSW
_vectors = vectors; _vectors = vectors;
} }
public void AddBidirectionallConnectionts(int q, int p, float qpDistance) public void AddBidirectionallConnectionts(int q, int p, float qpDistance, bool isMapLayer)
{ {
// поиск в ширину ближайших узлов к найденному // поиск в ширину ближайших узлов к найденному
var nearest = _links.FindLinksForId(p).ToArray(); var nearest = _links.FindLinksForId(p).ToArray();
// если у найденного узла максимальное количество связей // если у найденного узла максимальное количество связей
// if │eConn│ > Mmax // shrink connections of e // if │eConn│ > Mmax // shrink connections of e
if (nearest.Length >= _options.M) if (nearest.Length >= (isMapLayer ? _options.M * 2 : _options.M))
{ {
// ищем связь с самой большой дистанцией // ищем связь с самой большой дистанцией
float distance = nearest[0].Item3; float distance = nearest[0].Item3;
@ -55,6 +55,12 @@ namespace ZeroLevel.HNSW
} }
} }
/// <summary>
/// Adds a single link from <paramref name="q"/> to itself to this layer's link set,
/// passing 0 as the third argument (presumably the link distance; confirm against the link set's Add).
/// </summary>
/// <param name="q">Id of the node to add.</param>
public void Append(int q) => _links.Add(q, q, 0);
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary> /// <summary>
/// Algorithm 2 /// Algorithm 2

@ -43,7 +43,9 @@ namespace ZeroLevel.HNSW
{ {
_set.Remove(k1old); _set.Remove(k1old);
_set.Remove(k2old); _set.Remove(k2old);
if (!_set.ContainsKey(k1new))
_set.Add(k1new, distance); _set.Add(k1new, distance);
if (!_set.ContainsKey(k2new))
_set.Add(k2new, distance); _set.Add(k2new, distance);
} }
finally finally

@ -7,7 +7,7 @@ namespace ZeroLevel.HNSW
{ {
public class ProbabilityLayerNumberGenerator public class ProbabilityLayerNumberGenerator
{ {
private const float DIVIDER = 4.362f; private const float DIVIDER = 3.361f;
private readonly float[] _probabilities; private readonly float[] _probabilities;
public ProbabilityLayerNumberGenerator(int maxLayers, int M) public ProbabilityLayerNumberGenerator(int maxLayers, int M)
@ -38,11 +38,8 @@ namespace ZeroLevel.HNSW
private readonly NSWOptions<TItem> _options; private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors; private readonly VectorSet<TItem> _vectors;
private readonly Layer<TItem>[] _layers; private readonly Layer<TItem>[] _layers;
private int EntryPoint = 0;
private Layer<TItem> EnterPointsLayer => _layers[_layers.Length - 1]; private int MaxLayer = 0;
private Layer<TItem> LastLayer => _layers[0];
private int EntryPoint = -1;
private int MaxLayer = -1;
private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim();
@ -58,9 +55,12 @@ namespace ZeroLevel.HNSW
} }
} }
public IEnumerable<(int, TItem[])> Search(TItem vector, int k, HashSet<int> activeNodes = null) public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes = null)
{
foreach (var pair in KNearest(vector, k))
{ {
return Enumerable.Empty<(int, TItem[])>(); yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
} }
public int[] AddItems(IEnumerable<TItem> vectors) public int[] AddItems(IEnumerable<TItem> vectors)
@ -109,24 +109,25 @@ namespace ZeroLevel.HNSW
public void INSERT(int q) public void INSERT(int q)
{ {
var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate])); var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate]));
// W ← ∅ // list for the currently found nearest elements // W ← ∅ // list for the currently found nearest elements
IDictionary<int, float> W = new Dictionary<int, float>(); IDictionary<int, float> W = new Dictionary<int, float>();
// ep ← get enter point for hnsw // ep ← get enter point for hnsw
var ep = EntryPoint == -1 ? 0 : EntryPoint; var ep = EntryPoint;
var epDist = 0.0f; var epDist = distance(ep);
// L ← level of ep // top layer for hnsw // L ← level of ep // top layer for hnsw
var L = MaxLayer; var L = MaxLayer;
// l ← ⌊-ln(unif(0..1))∙mL⌋ // new elements level // l ← ⌊-ln(unif(0..1))∙mL⌋ // new elements level
int l = _layerLevelGenerator.GetRandomLayer(); int l = _layerLevelGenerator.GetRandomLayer();
if (L == -1)
{
L = l;
MaxLayer = l;
}
// for lc ← L … l+1 // for lc ← L … l+1
// Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа
for (int lc = L; lc > l; --lc) for (int lc = L; lc > l; --lc)
{
if (_layers[lc].CountLinks == 0)
{
_layers[lc].Append(q);
ep = q;
}
else
{ {
// W ← SEARCH-LAYER(q, ep, ef = 1, lc) // W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[lc].RunKnnAtLayer(ep, distance, W, 1); _layers[lc].RunKnnAtLayer(ep, distance, W, 1);
@ -136,20 +137,28 @@ namespace ZeroLevel.HNSW
epDist = nearest.Value; epDist = nearest.Value;
W.Clear(); W.Clear();
} }
}
//for lc ← min(L, l) … 0 //for lc ← min(L, l) … 0
// connecting new node to the small world // connecting new node to the small world
for (int lc = Math.Min(L, l); lc >= 0; --lc) for (int lc = Math.Min(L, l); lc >= 0; --lc)
{
if (_layers[lc].CountLinks == 0)
{
_layers[lc].Append(q);
ep = q;
}
else
{ {
// W ← SEARCH - LAYER(q, ep, efConstruction, lc) // W ← SEARCH - LAYER(q, ep, efConstruction, lc)
_layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction); _layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction);
// neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
var neighbors = SelectBestForConnecting(lc, distance, W);; var neighbors = SelectBestForConnecting(lc, distance, W);
// add bidirectionall connectionts from neighbors to q at layer lc // add bidirectionall connectionts from neighbors to q at layer lc
// for each e ∈ neighbors // shrink connections if needed // for each e ∈ neighbors // shrink connections if needed
foreach (var e in neighbors) foreach (var e in neighbors)
{ {
// eConn ← neighbourhood(e) at layer lc // eConn ← neighbourhood(e) at layer lc
_layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value); _layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value, lc == 0);
// if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
if (e.Value < epDist) if (e.Value < epDist)
{ {
@ -161,6 +170,7 @@ namespace ZeroLevel.HNSW
ep = W.OrderBy(p => p.Value).First().Key; ep = W.OrderBy(p => p.Value).First().Key;
W.Clear(); W.Clear();
} }
}
// if l > L // if l > L
if (l > L) if (l > L)
{ {

Loading…
Cancel
Save

Powered by TurnKey Linux.