using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using ZeroLevel.Services.Serialization;

namespace ZeroLevel.HNSW
{
    /// <summary>
    /// Keeps one <see cref="HNSWMap{TFeature}"/> per bucket index and routes features
    /// to the map of their bucket using a caller-supplied bucket function.
    /// </summary>
    /// <typeparam name="TFeature">Type of the features stored in the per-bucket maps.</typeparam>
    public class HNSWMappers<TFeature>
        : IBinarySerializable
    {
        // Size of the BufferedStream placed over the file stream when loading/saving (32 MiB).
        private const int FileBufferSize = 1024 * 1024 * 32;

        // Percentage given to every bucket's search context when no bucket reports available nodes.
        private const double FallbackPercentage = 0.2d;

        // Bucket index -> map for that bucket.
        private IDictionary<int, HNSWMap<TFeature>> _mappers;

        // Computes the bucket index a feature belongs to.
        private readonly Func<TFeature, int> _bucketFunction;

        /// <summary>
        /// Creates the mappers by deserializing them from <paramref name="filePath"/>.
        /// </summary>
        /// <param name="filePath">File previously written by <see cref="Save"/>.</param>
        /// <param name="bucketFunction">Maps a feature to its bucket index; used by <see cref="CreateContext"/>.</param>
        public HNSWMappers(string filePath, Func<TFeature, int> bucketFunction)
        {
            _bucketFunction = bucketFunction;
            using (var fs = File.OpenRead(filePath))
            using (var bs = new BufferedStream(fs, FileBufferSize))
            using (var reader = new MemoryStreamReader(bs))
            {
                Deserialize(reader);
            }
        }

        /// <summary>
        /// Serializes all per-bucket maps to <paramref name="filePath"/>, replacing any existing file.
        /// </summary>
        /// <param name="filePath">Destination file.</param>
        public void Save(string filePath)
        {
            // File.Create truncates an existing file. File.OpenWrite (OpenOrCreate) did not,
            // so writing a smaller snapshot over a larger one left stale trailing bytes behind.
            using (var fs = File.Create(filePath))
            using (var bs = new BufferedStream(fs, FileBufferSize))
            using (var writer = new MemoryStreamWriter(bs))
            {
                Serialize(writer);
            }
        }

        /// <summary>
        /// Creates an empty set of mappers; fill it with <see cref="Append"/>.
        /// </summary>
        /// <param name="bucketFunction">Maps a feature to its bucket index; used by <see cref="CreateContext"/>.</param>
        public HNSWMappers(Func<TFeature, int> bucketFunction)
        {
            _mappers = new Dictionary<int, HNSWMap<TFeature>>();
            _bucketFunction = bucketFunction;
        }

        /// <summary>
        /// Registers <paramref name="map"/> as the map of bucket <paramref name="c"/>.
        /// </summary>
        /// <remarks>Uses <c>IDictionary.Add</c>, so registering the same bucket twice throws.</remarks>
        public void Append(HNSWMap<TFeature> map, int c)
        {
            _mappers.Add(c, map);
        }

        /// <summary>
        /// Converts internal node ids of bucket <paramref name="c"/> back to features.
        /// </summary>
        /// <remarks>
        /// Lazily evaluated (iterator): the bucket lookup, and the exception for an unknown
        /// bucket, happen on first enumeration rather than at call time.
        /// </remarks>
        public IEnumerable<TFeature> ConvertIdsToFeatures(int c, IEnumerable<int> ids)
        {
            foreach (var feature in _mappers[c].ConvertIdsToFeatures(ids))
            {
                yield return feature;
            }
        }

        /// <summary>
        /// Builds one <see cref="SearchContext"/> per registered bucket from the given active
        /// nodes and entry points (either may be null).
        /// </summary>
        /// <param name="activeNodes">Features to mark as active; unknown buckets are logged and skipped.</param>
        /// <param name="entryPoints">Features to use as entry points; unknown buckets are logged and skipped.</param>
        /// <returns>Bucket index -> search context, covering every registered bucket.</returns>
        public IDictionary<int, SearchContext> CreateContext(IEnumerable<TFeature> activeNodes, IEnumerable<TFeature> entryPoints)
        {
            var actives = GroupIdsByBucket(activeNodes, "Active node");
            var entries = GroupIdsByBucket(entryPoints, "Entry point");

            var result = new Dictionary<int, SearchContext>();
            foreach (var pair in _mappers)
            {
                // A bucket with no active nodes / entry points gets null, not an empty list.
                var active = actives.GetValueOrDefault(pair.Key);
                var entry = entries.GetValueOrDefault(pair.Key);
                result.Add(pair.Key, new SearchContext().SetActiveNodes(active).SetEntryPointsNodes(entry));
            }

            var total = result.Values.Sum(v => v.AvaliableNodesCount);
            foreach (var context in result.Values)
            {
                if (total > 0)
                {
                    context.CaclulatePercentage(total);
                }
                else
                {
                    context.SetPercentage(FallbackPercentage);
                }
            }
            return result;
        }

        // Resolves each feature to (bucket, internal id) and groups the ids by bucket.
        // Features whose bucket has no registered map are logged as "<kind> <feature> is not
        // included in graphs!" and skipped. A null input yields an empty grouping.
        private Dictionary<int, List<int>> GroupIdsByBucket(IEnumerable<TFeature> features, string kind)
        {
            var grouped = new Dictionary<int, List<int>>();
            if (features == null)
            {
                return grouped;
            }
            foreach (var feature in features)
            {
                var c = _bucketFunction(feature);
                if (_mappers.TryGetValue(c, out var map))
                {
                    if (grouped.TryGetValue(c, out var ids) == false)
                    {
                        ids = new List<int>();
                        grouped.Add(c, ids);
                    }
                    ids.Add(map[feature]);
                }
                else
                {
                    Log.Warning($"{kind} {feature} is not included in graphs!");
                }
            }
            return grouped;
        }

        /// <summary>Replaces the current maps with the ones read from <paramref name="reader"/>.</summary>
        public void Deserialize(IBinaryReader reader)
        {
            this._mappers = reader.ReadDictionary<int, HNSWMap<TFeature>>();
        }

        /// <summary>Writes all per-bucket maps to <paramref name="writer"/>.</summary>
        public void Serialize(IBinaryWriter writer)
        {
            writer.WriteDictionary<int, HNSWMap<TFeature>>(this._mappers);
        }
    }
}