You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Zero/ZeroLevel.HNSW/SmallWorld.cs

380 lines
13 KiB

3 years ago
using System;
using System.Collections.Generic;
using System.IO;
3 years ago
using System.Linq;
using System.Threading;
using ZeroLevel.HNSW.Services;
using ZeroLevel.Services.Serialization;
3 years ago
namespace ZeroLevel.HNSW
{
public class SmallWorld<TItem>
{
private readonly NSWOptions<TItem> _options;
private VectorSet<TItem> _vectors;
private Layer<TItem>[] _layers;
3 years ago
private int EntryPoint = 0;
private int MaxLayer = 0;
private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim();
3 years ago
public readonly Func<int, int, float> DistanceFunction;
public TItem GetVector(int id) => _vectors[id];
public IDictionary<int, HashSet<int>> GetLinks() => _layers[0].Links;
3 years ago
public SmallWorld(NSWOptions<TItem> options)
{
_options = options;
_vectors = new VectorSet<TItem>();
_layers = new Layer<TItem>[_options.LayersCount];
_layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M);
DistanceFunction = new Func<int, int, float>((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2]));
3 years ago
for (int i = 0; i < _options.LayersCount; i++)
{
_layers[i] = new Layer<TItem>(_options, _vectors, i == 0);
3 years ago
}
}
public SmallWorld(NSWOptions<TItem> options, Stream stream)
{
_options = options;
_layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M);
DistanceFunction = new Func<int, int, float>((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2]));
Deserialize(stream);
}
/// <summary>
/// Search in the graph K for vectors closest to a given vector
/// </summary>
/// <param name="vector">Given vector</param>
/// <param name="k">Count of elements for search</param>
/// <param name="activeNodes"></param>
/// <returns></returns>
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k)
3 years ago
{
foreach (var pair in KNearest(vector, k))
3 years ago
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
3 years ago
}
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context)
{
if (context == null)
{
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
else
{
foreach (var pair in KNearest(vector, k, context))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
}
/*
public IEnumerable<(int, TItem, float)> Search(int k, SearchContext context)
{
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}
else
{
foreach (var pair in KNearest(k, context))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
}
*/
/// <summary>
/// Adding vectors batch
/// </summary>
/// <param name="vectors">Vectors</param>
/// <returns>Vector identifiers in a graph</returns>
3 years ago
public int[] AddItems(IEnumerable<TItem> vectors)
{
_lockGraph.EnterWriteLock();
try
3 years ago
{
var ids = _vectors.Append(vectors);
for (int i = 0; i < ids.Length; i++)
{
INSERT(ids[i]);
}
return ids;
}
finally
{
_lockGraph.ExitWriteLock();
3 years ago
}
}
#region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary>
/// Algorithm 1
/// </summary>
private void INSERT(int q)
3 years ago
{
var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate]));
3 years ago
// W ← ∅ // list for the currently found nearest elements
var W = new MinHeap(_options.EFConstruction + 1);
3 years ago
// ep ← get enter point for hnsw
var ep = _layers[MaxLayer].FindEntryPointAtLayer(distance);
if (ep == -1)
ep = EntryPoint;
3 years ago
var epDist = distance(ep);
3 years ago
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
3 years ago
// l ← ⌊-ln(unif(0..1))∙mL⌋ // new elements level
int l = _layerLevelGenerator.GetRandomLayer();
// Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа
int id;
float value;
// for lc ← L … l+1
for (int lc = L; lc > l; --lc)
3 years ago
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, 1))
3 years ago
{
W.Push(i);
3 years ago
}
// ep ← get the nearest element from W to q
if (W.TryPeek(out id, out value))
3 years ago
{
ep = id;
epDist = value;
3 years ago
}
W.Clear();
3 years ago
}
//for lc ← min(L, l) … 0
// connecting new node to the small world
for (int lc = Math.Min(L, l); lc >= 0; --lc)
3 years ago
{
_layers[lc].Push(q, ep, W, distance);
// ep ← W
if (W.TryPeek(out id, out value))
3 years ago
{
ep = id;
epDist = value;
3 years ago
}
W.Clear();
}
// if l > L
if (l > L)
{
// set enter point for hnsw to q
L = l;
MaxLayer = l;
EntryPoint = ep;
3 years ago
}
}
public void TestWorld()
{
for (var v = 0; v < _vectors.Count; v++)
{
var nearest = _layers[0][v].ToArray();
if (nearest.Length > _layers[0].M)
{
Console.WriteLine($"V{v}. Count of links ({nearest.Length}) more than max ({_layers[0].M})");
}
}
// coverage test
var ep = 0;
var visited = new HashSet<int>();
var next = new Stack<int>();
next.Push(ep);
while (next.Count > 0)
{
ep = next.Pop();
visited.Add(ep);
foreach (var n in _layers[0].GetNeighbors(ep))
{
if (visited.Contains(n) == false)
{
next.Push(n);
}
}
}
if (visited.Count != _vectors.Count)
{
Console.Write($"Vectors count ({_vectors.Count}) less than BFS visited nodes count ({visited.Count})");
}
3 years ago
}
/// <summary>
/// Algorithm 5
/// </summary>
private IEnumerable<(int, float)> KNearest(TItem q, int k)
{
_lockGraph.EnterReadLock();
try
{
if (_vectors.Count == 0)
{
return Enumerable.Empty<(int, float)>();
}
int id;
float value;
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new MinHeap(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, 1))
{
W.Push(i);
}
// ep ← get nearest element from W to q
if (W.TryPeek(out id, out value))
{
ep = id;
}
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
foreach (var i in _layers[0].KNearestAtLayer(ep, distance, k))
{
W.Push(i);
}
// return K nearest elements from W to q
return W;
}
finally
{
_lockGraph.ExitReadLock();
}
}
private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context)
3 years ago
{
_lockGraph.EnterReadLock();
try
3 years ago
{
if (_vectors.Count == 0)
{
return Enumerable.Empty<(int, float)>();
}
int id;
float value;
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new MinHeap(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, 1))
{
W.Push(i);
}
// ep ← get nearest element from W to q
if (W.TryPeek(out id, out value))
{
ep = id;
}
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
foreach (var i in _layers[0].KNearestAtLayer(ep, distance, k, context))
{
W.Push(i);
}
// return K nearest elements from W to q
return W;
}
finally
{
_lockGraph.ExitReadLock();
3 years ago
}
}
/*
private IEnumerable<(int, float)> KNearest(int k, SearchContext context)
{
_lockGraph.EnterReadLock();
try
{
if (_vectors.Count == 0)
{
return Enumerable.Empty<(int, float)>();
}
// W ← ∅ // set for the current nearest elements
var W = new MinHeap(k + 1);
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
foreach (var i in _layers[0].KNearestAtLayer(W, k, context))
{
W.Push(i);
}
// return K nearest elements from W to q
return W;
}
finally
{
_lockGraph.ExitReadLock();
}
}
*/
3 years ago
#endregion
public void Serialize(Stream stream)
{
using (var writer = new MemoryStreamWriter(stream))
{
writer.WriteInt32(EntryPoint);
writer.WriteInt32(MaxLayer);
_vectors.Serialize(writer);
writer.WriteInt32(_layers.Length);
foreach (var l in _layers)
{
l.Serialize(writer);
}
}
}
public void Deserialize(Stream stream)
{
using (var reader = new MemoryStreamReader(stream))
{
this.EntryPoint = reader.ReadInt32();
this.MaxLayer = reader.ReadInt32();
_vectors = new VectorSet<TItem>();
_vectors.Deserialize(reader);
var countLayers = reader.ReadInt32();
_layers = new Layer<TItem>[countLayers];
for (int i = 0; i < countLayers; i++)
{
_layers[i] = new Layer<TItem>(_options, _vectors, i == 0);
_layers[i].Deserialize(reader);
}
}
}
3 years ago
}
}

Powered by TurnKey Linux.