HNSW update

Fix bugs
pull/1/head
unknown 3 years ago
parent 1f967908b9
commit 0e2f693b87

@ -77,13 +77,13 @@ namespace HNSWDemo
static void Main(string[] args) static void Main(string[] args)
{ {
var dimensionality = 128; var dimensionality = 128;
var testCount = 1000; var testCount = 5000;
var count = 100000; var count = 10000;
var batchSize = 5000; var batchSize = 1000;
var samples = Person.GenerateRandom(dimensionality, count); var samples = Person.GenerateRandom(dimensionality, count);
var sw = new Stopwatch(); var sw = new Stopwatch();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 4, 120, 120, CosineDistance.ForUnits)); var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 10, 200, 200, CosineDistance.ForUnits, false, false, selectionHeuristic: NeighbourSelectionHeuristic.SelectHeuristic));
for (int i = 0; i < (count / batchSize); i++) for (int i = 0; i < (count / batchSize); i++)
{ {

@ -94,7 +94,9 @@ namespace ZeroLevel.HNSW
// W ← ep // dynamic list of found nearest neighbors // W ← ep // dynamic list of found nearest neighbors
W.Add(entryPointId, C[entryPointId]); W.Add(entryPointId, C[entryPointId]);
var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
// run bfs // run bfs
while (C.Count > 0) while (C.Count > 0)
{ {
@ -164,14 +166,14 @@ namespace ZeroLevel.HNSW
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param> /// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param> /// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns> /// <returns>Output: M elements selected by the heuristic</returns>
public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M, bool extendCandidates, bool keepPrunedConnections) public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{ {
// R ← ∅ // R ← ∅
var R = new Dictionary<int, float>(); var R = new Dictionary<int, float>();
// W ← C // working queue for the candidates // W ← C // working queue for the candidates
var W = new Dictionary<int, float>(candidates); var W = new Dictionary<int, float>(candidates);
// if extendCandidates // extend candidates by their neighbors // if extendCandidates // extend candidates by their neighbors
if (extendCandidates) if (_options.ExpandBestSelection)
{ {
var extendBuffer = new HashSet<int>(); var extendBuffer = new HashSet<int>();
// for each e ∈ C // for each e ∈ C
@ -191,7 +193,7 @@ namespace ZeroLevel.HNSW
// W ← W ∪ eadj // W ← W ∪ eadj
foreach (var id in extendBuffer) foreach (var id in extendBuffer)
{ {
W.Add(id, distance(id)); W[id] = distance(id);
} }
} }
@ -201,7 +203,7 @@ namespace ZeroLevel.HNSW
var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); });
// while │W│ > 0 and │R│< M // while │W│ > 0 and │R│< M
@ -218,21 +220,21 @@ namespace ZeroLevel.HNSW
// R ← R ∪ e // R ← R ∪ e
R.Add(e, ed); R.Add(e, ed);
} }
// else else
{ {
// Wd ← Wd ∪ e // Wd ← Wd ∪ e
Wd.Add(e, ed); Wd.Add(e, ed);
} }
// if keepPrunedConnections // add some of the discarded // connections from Wd }
if (keepPrunedConnections) // if keepPrunedConnections // add some of the discarded // connections from Wd
if (_options.KeepPrunedConnections)
{
// while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < M)
{ {
// while │Wd│> 0 and │R│< M // R ← R ∪ extract nearest element from Wd to q
while (Wd.Count > 0 && R.Count < M) var nearest = popNearestDiscarded();
{ R[nearest.Item1] = nearest.Item2;
// R ← R ∪ extract nearest element from Wd to q
var nearest = popNearestDiscarded();
R.Add(nearest.Item1, nearest.Item2);
}
} }
} }
// return R // return R

@ -2,10 +2,24 @@
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
public sealed class NSWOptions<TItem> /// <summary>
/// Type of heuristic to select best neighbours for a node.
/// </summary>
public enum NeighbourSelectionHeuristic
{ {
public const int FARTHEST_DIVIDER = 3; /// <summary>
/// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in <see cref="Algorithms.Algorithm3{TItem, TDistance}"/>
/// </summary>
SelectSimple,
/// <summary>
/// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in <see cref="Algorithms.Algorithm4{TItem, TDistance}"/>
/// </summary>
SelectHeuristic
}
public sealed class NSWOptions<TItem>
{
/// <summary> /// <summary>
/// Max node connections on Layer /// Max node connections on Layer
/// </summary> /// </summary>
@ -24,19 +38,42 @@ namespace ZeroLevel.HNSW
/// </summary> /// </summary>
public readonly Func<TItem, TItem, float> Distance; public readonly Func<TItem, TItem, float> Distance;
public readonly bool ExpandBestSelection;
public readonly bool KeepPrunedConnections;
public readonly NeighbourSelectionHeuristic SelectionHeuristic;
public readonly int LayersCount; public readonly int LayersCount;
private NSWOptions(int layersCount, int m, int ef, int ef_construction, Func<TItem, TItem, float> distance) private NSWOptions(int layersCount,
int m,
int ef,
int ef_construction,
Func<TItem, TItem, float> distance,
bool expandBestSelection,
bool keepPrunedConnections,
NeighbourSelectionHeuristic selectionHeuristic)
{ {
LayersCount = layersCount; LayersCount = layersCount;
M = m; M = m;
EF = ef; EF = ef;
EFConstruction = ef_construction; EFConstruction = ef_construction;
Distance = distance; Distance = distance;
ExpandBestSelection = expandBestSelection;
KeepPrunedConnections = keepPrunedConnections;
SelectionHeuristic = selectionHeuristic;
} }
public static NSWOptions<TItem> Create(int layersCount, int M, int EF, int EF_construction, Func<TItem, TItem, float> distance) => public static NSWOptions<TItem> Create(int layersCount,
new NSWOptions<TItem>(layersCount, M, EF, EF_construction, distance); int M,
int EF,
int EF_construction,
Func<TItem, TItem, float> distance,
bool expandBestSelection = false,
bool keepPrunedConnections = false,
NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) =>
new NSWOptions<TItem>(layersCount, M, EF, EF_construction, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic);
} }
} }

@ -14,15 +14,14 @@ namespace ZeroLevel.HNSW
private SortedList<long, float> _set = new SortedList<long, float>(); private SortedList<long, float> _set = new SortedList<long, float>();
public (int, int, float) this[int index] public (int, int) this[int index]
{ {
get get
{ {
var k = _set.Keys[index]; var k = _set.Keys[index];
var d = _set.Values[index];
var id1 = (int)(k >> HALF_LONG_BITS); var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
return (id1, id2, d); return (id1, id2);
} }
} }
@ -72,10 +71,22 @@ namespace ZeroLevel.HNSW
{ {
_set.Remove(k_id1_id2); _set.Remove(k_id1_id2);
_set.Remove(k_id2_id1); _set.Remove(k_id2_id1);
_set.Add(k_id_id1, distanceToId1); if (!_set.ContainsKey(k_id_id1))
_set.Add(k_id1_id, distanceToId1); {
_set.Add(k_id_id2, distanceToId2); _set.Add(k_id_id1, distanceToId1);
_set.Add(k_id2_id, distanceToId2); }
if (!_set.ContainsKey(k_id1_id))
{
_set.Add(k_id1_id, distanceToId1);
}
if (!_set.ContainsKey(k_id_id2))
{
_set.Add(k_id_id2, distanceToId2);
}
if (!_set.ContainsKey(k_id2_id))
{
_set.Add(k_id2_id, distanceToId2);
}
} }
finally finally
{ {

@ -192,7 +192,9 @@ namespace ZeroLevel.HNSW
private IDictionary<int, float> SelectBestForConnecting(int layer, Func<int, float> distance, IDictionary<int, float> candidates) private IDictionary<int, float> SelectBestForConnecting(int layer, Func<int, float> distance, IDictionary<int, float> candidates)
{ {
return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer)); if(_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple)
return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer));
return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer));
} }
/// <summary> /// <summary>

Loading…
Cancel
Save

Powered by TurnKey Linux.