diff --git a/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs b/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs index f36dc65..cbd4fc5 100644 --- a/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs +++ b/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using ZeroLevel.HNSW; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -25,8 +26,8 @@ namespace HNSWDemo.Tests var sw = new Stopwatch(); - var test = new VectorsDirectCompare(samples, Metrics.Cosine); - var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, Metrics.Cosine)); + var test = new VectorsDirectCompare(samples, Metrics.CosineDistance); + var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, Metrics.CosineDistance)); sw.Start(); var ids = world.AddItems(samples.ToArray()); diff --git a/TestHNSW/HNSWDemo/Tests/AutoClusteringMNISTTest.cs b/TestHNSW/HNSWDemo/Tests/AutoClusteringMNISTTest.cs index eabfa4a..faa8c91 100644 --- a/TestHNSW/HNSWDemo/Tests/AutoClusteringMNISTTest.cs +++ b/TestHNSW/HNSWDemo/Tests/AutoClusteringMNISTTest.cs @@ -8,6 +8,7 @@ using System.Runtime.InteropServices; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; using ZeroLevel.Services.FileSystem; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -58,7 +59,7 @@ namespace HNSWDemo.Tests vectors.Add(v); } } - var options = NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean); + var options = NSWOptions.Create(8, 16, 200, 200, Metrics.L2EuclideanDistance); SmallWorld world; if (File.Exists("graph_mnist.bin")) { @@ -77,7 +78,7 @@ namespace HNSWDemo.Tests } } - var distance = new Func((id1, id2) => Metrics.L2Euclidean(world.GetVector(id1), world.GetVector(id2))); + var distance = new Func((id1, id2) => Metrics.L2EuclideanDistance(world.GetVector(id1), world.GetVector(id2))); var links = world.GetLinks().SelectMany(pair => pair.Value.Select(p=> distance(pair.Key, p))).ToList(); var exists = links.Where(n => n > 0).ToArray(); diff --git a/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs b/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs index 9bdbbe0..2e77c21 100644 --- a/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs +++ b/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs @@ -1,6 +1,7 @@ using System; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -13,7 +14,7 @@ namespace HNSWDemo.Tests public void Run() { var vectors = VectorUtils.RandomVectors(Dimensionality, Count); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean)); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2EuclideanDistance)); world.AddItems(vectors); var clusters = AutomaticGraphClusterer.DetectClusters(world); Console.WriteLine($"Found {clusters.Count} clusters"); diff --git a/TestHNSW/HNSWDemo/Tests/HistogramTest.cs b/TestHNSW/HNSWDemo/Tests/HistogramTest.cs index 9b803eb..e13375d 100644 --- a/TestHNSW/HNSWDemo/Tests/HistogramTest.cs +++ b/TestHNSW/HNSWDemo/Tests/HistogramTest.cs @@ -3,6 +3,7 @@ using System.Drawing; using System.IO; using System.Linq; using ZeroLevel.HNSW; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -28,10 +29,10 @@ namespace HNSWDemo.Tests private void Create(int dim, string output) { var vectors = VectorUtils.RandomVectors(dim, Count); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean)); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2EuclideanDistance)); world.AddItems(vectors); - var distance = new Func((id1, id2) => Metrics.L2Euclidean(world.GetVector(id1), world.GetVector(id2))); + var distance = new Func((id1, id2) => Metrics.L2EuclideanDistance(world.GetVector(id1), world.GetVector(id2))); var weights = world.GetLinks().SelectMany(pair => pair.Value.Select(id => distance(pair.Key, id))); var histogram = new Histogram(HistogramMode.SQRT, weights); histogram.Smooth(); diff --git a/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs b/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs index f451540..d4c7513 100644 --- a/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs +++ b/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics; using ZeroLevel.HNSW; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -14,7 +15,7 @@ namespace HNSWDemo.Tests public void Run() { var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, Metrics.Cosine)); + var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, Metrics.CosineDistance)); for (int i = 0; i < IterationCount; i++) { var samples = VectorUtils.RandomVectors(Dimensionality, Count); diff --git a/TestHNSW/HNSWDemo/Tests/LALTest.cs b/TestHNSW/HNSWDemo/Tests/LALTest.cs index 8f21151..dbeedcd 100644 --- a/TestHNSW/HNSWDemo/Tests/LALTest.cs +++ b/TestHNSW/HNSWDemo/Tests/LALTest.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using ZeroLevel.HNSW; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -20,7 +21,7 @@ namespace HNSWDemo.Tests var moda = 3; var persons = Person.GenerateRandom(dimensionality, count); var samples = new Dictionary>(); - var options = NSWOptions.Create(6, 8, 100, 100, Metrics.Cosine); + var options = NSWOptions.Create(6, 8, 100, 100, Metrics.CosineDistance); foreach (var p in persons) { diff --git a/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs index 3566779..ffa01e8 100644 --- a/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs +++ b/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -26,11 +27,11 @@ namespace HNSWDemo.Tests { var v1 = samples[i]; var v2 = samples[i + 1]; - var dist = Metrics.Cosine(v1, v2); + var dist = Metrics.CosineDistance(v1, v2); var qv1 = q_samples[i]; var qv2 = q_samples[i + 1]; - var qdist = Metrics.Cosine(qv1, qv2); + var qdist = Metrics.CosineDistance(qv1, qv2); list.Add(Math.Abs(dist - qdist)); } diff --git a/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs index 324a92d..7518257 100644 --- a/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs +++ b/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Linq; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -28,8 +29,8 @@ namespace HNSWDemo.Tests var sw = new Stopwatch(); - var test = new VectorsDirectCompare(s, Metrics.Cosine); - var world = new SmallWorld(NSWOptions.Create(6, 8, 100, 100, Metrics.Cosine)); + var test = new VectorsDirectCompare(s, Metrics.CosineDistance); + var world = new SmallWorld(NSWOptions.Create(6, 8, 100, 100, Metrics.CosineDistance)); sw.Start(); var ids = world.AddItems(samples.ToArray()); diff --git a/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs index f71a33c..53fa40e 100644 --- a/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs +++ b/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs @@ -3,6 +3,7 @@ using System.Drawing; using System.Linq; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -18,10 +19,10 @@ namespace HNSWDemo.Tests { var vectors = VectorUtils.RandomVectors(Dimensionality, Count); var q = new Quantizator(-1f, 1f); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.Cosine)); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.CosineDistance)); world.AddItems(vectors.Select(v => q.QuantizeToLong(v)).ToList()); - var distance = new Func((id1, id2) => Metrics.Cosine(world.GetVector(id1), world.GetVector(id2))); + var distance = new Func((id1, id2) => Metrics.CosineDistance(world.GetVector(id1), world.GetVector(id2))); var weights = world.GetLinks().SelectMany(pair => pair.Value.Select(id => distance(pair.Key, id))); var histogram = new Histogram(HistogramMode.SQRT, weights); histogram.Smooth(); diff --git a/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs index b42ddde..a0c6679 100644 --- a/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs +++ b/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Linq; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Mathemathics; namespace HNSWDemo.Tests { @@ -16,7 +17,7 @@ namespace HNSWDemo.Tests public void Run() { var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, Metrics.Cosine)); + var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, Metrics.CosineDistance)); var q = new Quantizator(-1f, 1f); for (int i = 0; i < IterationCount; i++) { diff --git a/ZeroLevel.NN/Models/ImagePreprocessorOptions.cs b/ZeroLevel.NN/Models/ImagePreprocessorOptions.cs index 64c819c..2566676 100644 --- a/ZeroLevel.NN/Models/ImagePreprocessorOptions.cs +++ b/ZeroLevel.NN/Models/ImagePreprocessorOptions.cs @@ -2,6 +2,7 @@ { public class ImagePreprocessorOptions { + private const float PIXEL_NORMALIZATION_SCALE = 1f / 255f; public ImagePreprocessorOptions(int inputWidth, int inputHeight, PredictorChannelType channelType) { this.InputWidth = inputWidth; @@ -19,8 +20,12 @@ return this; } - public ImagePreprocessorOptions ApplyNormilization() + public ImagePreprocessorOptions ApplyNormilization(float? multiplier = null) { + if (multiplier.HasValue) + { + NormalizationMultiplier = multiplier.Value; + } this.Normalize = true; return this; } @@ -60,6 +65,8 @@ return this; } + + public float NormalizationMultiplier { get; private set; } = PIXEL_NORMALIZATION_SCALE; /// /// Channel type, if first tensor dims = [batch_index, channel, x, y], if last, dims = dims = [batch_index, x, y, channel] /// diff --git a/ZeroLevel.NN/Services/Clusterization/FeatureCluster.cs b/ZeroLevel.NN/Services/Clusterization/FeatureCluster.cs new file mode 100644 index 0000000..55cde57 --- /dev/null +++ b/ZeroLevel.NN/Services/Clusterization/FeatureCluster.cs @@ -0,0 +1,64 @@ +namespace ZeroLevel.NN.Clusterization +{ + public class FeatureCluster + { + private readonly List _features = new List(); + private readonly Func _vectorExtractor; + public FeatureCluster(Func vectorExtractor) + { + _vectorExtractor = vectorExtractor; + } + + public IReadOnlyList Features => _features; + + internal void Append(T face) => _features.Add(face); + public bool IsNeighbor(T feature, Func similarityFunction, float threshold, float clusterThreshold = 0.5f) + { + if (_features.Count == 0) return true; + if (_features.Count == 1) + { + var similarity = similarityFunction(_vectorExtractor(feature), _vectorExtractor(_features[0])); + return similarity >= threshold; + } + var clusterNearestElementsCount = 0; + foreach (var f in _features) + { + var similarity = similarityFunction(_vectorExtractor(feature), _vectorExtractor(f)); + if (similarity >= threshold) + { + clusterNearestElementsCount++; + } + } + var clusterToFaceScore = (float)clusterNearestElementsCount / (float)_features.Count; + return clusterToFaceScore > clusterThreshold; + } + + public bool IsNeighborCluster(FeatureCluster cluster, Func similarityFunction, float threshold, float clusterThreshold = 0.5f) + { + if (_features.Count == 0) return true; + if (_features.Count == 1 && cluster.IsNeighbor(_features[0], similarityFunction, threshold, clusterThreshold)) + { + return true; + } + var clusterNearestElementsCount = 0; + foreach (var f in _features) + { + if (cluster.IsNeighbor(f, similarityFunction, threshold, clusterThreshold)) + { + clusterNearestElementsCount++; + } + } + var localCount = _features.Count; + var remoteCount = cluster.Features.Count; + var localIntersection = (float)clusterNearestElementsCount / (float)localCount; + var remoteIntersection = (float)clusterNearestElementsCount / (float)remoteCount; + var score = Math.Max(localIntersection, remoteIntersection); + return score > clusterThreshold; + } + + public void Merge(FeatureCluster other) + { + this._features.AddRange(other.Features); + } + } +} diff --git a/ZeroLevel.NN/Services/Clusterization/FeatureClusterBulder.cs b/ZeroLevel.NN/Services/Clusterization/FeatureClusterBulder.cs new file mode 100644 index 0000000..800c304 --- /dev/null +++ b/ZeroLevel.NN/Services/Clusterization/FeatureClusterBulder.cs @@ -0,0 +1,65 @@ +namespace ZeroLevel.NN.Clusterization +{ + public class FeatureClusterBulder + { + public FeatureClusterCollection Build(IEnumerable faces, Func vectorExtractor, Func similarityFunction, float threshold, float clusterThreshold = 0.5f) + { + var collection = new FeatureClusterCollection(); + foreach (var face in faces) + { + bool isAdded = false; + foreach (var cluster in collection.Clusters) + { + if (cluster.Value.IsNeighbor(face, similarityFunction, threshold, clusterThreshold)) + { + cluster.Value.Append(face); + isAdded = true; + break; + } + } + if (false == isAdded) + { + var cluster = new FeatureCluster(vectorExtractor); + cluster.Append(face); + collection.Add(cluster); + } + } + MergeClusters(collection, similarityFunction, threshold, clusterThreshold); + return collection; + } + + private void MergeClusters(FeatureClusterCollection collection, Func similarityFunction, float threshold, float clusterThreshold = 0.5f) + { + int lastCount = collection.Clusters.Count; + var removed = new Queue(); + do + { + var ids = collection.Clusters.Keys.ToList(); + for (var i = 0; i < ids.Count - 1; i++) + { + for (var j = i + 1; j < ids.Count; j++) + { + var c1 = collection.Clusters[ids[i]]; + var c2 = collection.Clusters[ids[j]]; + if (c1.IsNeighborCluster(c2, similarityFunction, threshold, clusterThreshold)) + { + c1.Merge(c2); + removed.Enqueue(ids[j]); + ids.RemoveAt(j); + j--; + } + } + } + while (removed.Count > 0) + { + collection.Clusters.Remove(removed.Dequeue()); + } + if (lastCount == collection.Clusters.Count) + { + break; + } + lastCount = collection.Clusters.Count; + } while (true); + } + } +} diff --git a/ZeroLevel.NN/Services/Clusterization/FeatureClusterCollection.cs b/ZeroLevel.NN/Services/Clusterization/FeatureClusterCollection.cs new file mode 100644 index 0000000..92b052c --- /dev/null +++ b/ZeroLevel.NN/Services/Clusterization/FeatureClusterCollection.cs @@ -0,0 +1,15 @@ +namespace ZeroLevel.NN.Clusterization +{ + public class FeatureClusterCollection + { + private int _clusterKey = 0; + private IDictionary> _clusters = new Dictionary>(); + + public IDictionary> Clusters => _clusters; + + internal void Add(FeatureCluster cluster) + { + _clusters.Add(Interlocked.Increment(ref _clusterKey), cluster); + } + } +} diff --git a/ZeroLevel.NN/Services/ImagePreprocessor.cs b/ZeroLevel.NN/Services/ImagePreprocessor.cs index c40e4e5..b5e824e 100644 --- a/ZeroLevel.NN/Services/ImagePreprocessor.cs +++ b/ZeroLevel.NN/Services/ImagePreprocessor.cs @@ -8,8 +8,6 @@ namespace ZeroLevel.NN { public static class ImagePreprocessor { - private const float NORMALIZATION_SCALE = 1f / 255f; - private static Func PixelToTensorMethod(ImagePreprocessorOptions options) { if (options.Normalize) @@ -18,16 +16,16 @@ namespace ZeroLevel.NN { if (options.CorrectionFunc == null) { - return new Func((b, i) => ((NORMALIZATION_SCALE * (float)b) - options.Mean[i]) / options.Std[i]); + return new Func((b, i) => ((options.NormalizationMultiplier * (float)b) - options.Mean[i]) / options.Std[i]); } else { - return new Func((b, i) => options.CorrectionFunc.Invoke(i, NORMALIZATION_SCALE * (float)b)); + return new Func((b, i) => options.CorrectionFunc.Invoke(i, options.NormalizationMultiplier * (float)b)); } } else { - return new Func((b, i) => NORMALIZATION_SCALE * (float)b); + return new Func((b, i) => options.NormalizationMultiplier * (float)b); } } else if (options.Correction) @@ -58,6 +56,7 @@ namespace ZeroLevel.NN } return count; } + private static void FillTensor(Tensor tensor, Image image, int index, ImagePreprocessorOptions options, Func pixToTensor) { var append = options.ChannelType == PredictorChannelType.ChannelFirst diff --git a/ZeroLevel/Services/Extensions/ArrayExtensions.cs b/ZeroLevel/Services/Extensions/ArrayExtensions.cs index 0cc355d..f8413a5 100644 --- a/ZeroLevel/Services/Extensions/ArrayExtensions.cs +++ b/ZeroLevel/Services/Extensions/ArrayExtensions.cs @@ -99,7 +99,11 @@ namespace ZeroLevel if (ReferenceEquals(first, second)) return true; if (first.Length != second.Length) return false; - return Array.Equals(first, second); + for (int i = 0; i < first.Length; i++) + { + if (first[i] != second[i]) return false; + } + return true; } } } \ No newline at end of file diff --git a/ZeroLevel/Services/Mathemathics/Metrics.cs b/ZeroLevel/Services/Mathemathics/Metrics.cs index e63c2b2..a41daf1 100644 --- a/ZeroLevel/Services/Mathemathics/Metrics.cs +++ b/ZeroLevel/Services/Mathemathics/Metrics.cs @@ -4,7 +4,7 @@ namespace ZeroLevel.Services.Mathemathics { public enum KnownMetrics { - Cosine, Manhattanm, Euclide, Chebyshev + Cosine, Manhattanm, Euclide, Chebyshev, DotProduct } @@ -22,6 +22,8 @@ namespace ZeroLevel.Services.Mathemathics return new Func((u, v) => ChebyshevDistance(u, v)); case KnownMetrics.Manhattanm: return new Func((u, v) => L1ManhattanDistance(u, v)); + case KnownMetrics.DotProduct: + return new Func((u, v) => DotProductDistance(u, v)); } throw new Exception($"Metric '{metric.ToString()}' not supported for Float type"); } @@ -37,8 +39,9 @@ namespace ZeroLevel.Services.Mathemathics case KnownMetrics.Chebyshev: return new Func((u, v) => ChebyshevDistance(u, v)); case KnownMetrics.Manhattanm: - return new Func((u, v) => L1ManhattanDistance - (u, v)); + return new Func((u, v) => L1ManhattanDistance(u, v)); + case KnownMetrics.DotProduct: + return new Func((u, v) => DotProductDistance(u, v)); } throw new Exception($"Metric '{metric.ToString()}' not supported for Byte type"); } @@ -55,6 +58,8 @@ namespace ZeroLevel.Services.Mathemathics return new Func((u, v) => ChebyshevDistance(u, v)); case KnownMetrics.Manhattanm: return new Func((u, v) => L1ManhattanDistance(u, v)); + case KnownMetrics.DotProduct: + return new Func((u, v) => DotProductDistance(u, v)); } throw new Exception($"Metric '{metric.ToString()}' not supported for Long type"); } @@ -71,6 +76,8 @@ namespace ZeroLevel.Services.Mathemathics return new Func((u, v) => ChebyshevDistance(u, v)); case KnownMetrics.Manhattanm: return new Func((u, v) => L1ManhattanDistance(u, v)); + case KnownMetrics.DotProduct: + return new Func((u, v) => DotProductDistance(u, v)); } throw new Exception($"Metric '{metric.ToString()}' not supported for Int type"); } @@ -434,12 +441,44 @@ namespace ZeroLevel.Services.Mathemathics return 1 - similarity; } - public static float CosineClipped(float[] u, float[] v, float min, float max) + public static double DotProductDistance(float[] e1, float[] e2) { - var similarity = CosineDistance(u, v); - if (min > similarity) similarity = min; - if (max < similarity) similarity = max; - return similarity; + var sim = 0f; + for (int i = 0; i < e1.Length; i++) + { + sim += e1[i] * e2[i]; + } + return sim; + } + + public static double DotProductDistance(byte[] e1, byte[] e2) + { + var sim = 0f; + for (int i = 0; i < e1.Length; i++) + { + sim += e1[i] * e2[i]; + } + return sim; + } + + public static double DotProductDistance(int[] e1, int[] e2) + { + var sim = 0f; + for (int i = 0; i < e1.Length; i++) + { + sim += e1[i] * e2[i]; + } + return sim; + } + + public static double DotProductDistance(long[] e1, long[] e2) + { + var sim = 0f; + for (int i = 0; i < e1.Length; i++) + { + sim += e1[i] * e2[i]; + } + return sim; } } }