diff --git a/TestApp/Program.cs b/TestApp/Program.cs
index 04923dd..befde05 100644
--- a/TestApp/Program.cs
+++ b/TestApp/Program.cs
@@ -1,4 +1,5 @@
 using Newtonsoft.Json;
+using System;
 using ZeroLevel;
 using ZeroLevel.Logging;
diff --git a/ZeroLevel/Services/Semantic/CValue/Document.cs b/ZeroLevel/Services/Semantic/CValue/Document.cs
new file mode 100644
index 0000000..6a30e64
--- /dev/null
+++ b/ZeroLevel/Services/Semantic/CValue/Document.cs
@@ -0,0 +1,73 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+
+namespace ZeroLevel.Services.Semantic.CValue
+{
+    public class Document
+    {
+        private String path;
+        private List<String> sentenceList;
+        private List<List<Token>> tokenList;
+        private String name;
+        private List<Term> termList;
+
+        public Document(String pPath, String pName)
+        {
+            path = pPath;
+            name = pName;
+            termList = new List<Term>();
+        }
+
+        public String getPath()
+        {
+            return path;
+        }
+
+        public void setPath(String path)
+        {
+            this.path = path;
+        }
+
+        public List<String> getSentenceList()
+        {
+            return sentenceList;
+        }
+
+        public void setSentenceList(List<String> sentenceList)
+        {
+            this.sentenceList = sentenceList;
+        }
+
+        public List<List<Token>> getTokenList()
+        {
+            return tokenList;
+        }
+
+        public void setTokenList(List<List<Token>> tokenList)
+        {
+            this.tokenList = tokenList;
+        }
+
+        public String getName()
+        {
+            return name;
+        }
+
+        public void setName(String name)
+        {
+            this.name = name;
+        }
+
+        public List<Term> getTermList()
+        {
+            return termList;
+        }
+
+        public void setTermList(List<Term> termList)
+        {
+            this.termList = termList;
+        }
+    }
+}
diff --git a/ZeroLevel/Services/Semantic/CValue/Term.cs b/ZeroLevel/Services/Semantic/CValue/Term.cs
new file mode 100644
index 0000000..788fb26
--- /dev/null
+++ b/ZeroLevel/Services/Semantic/CValue/Term.cs
@@ -0,0 +1,72 @@
+using System;
+
+namespace ZeroLevel.Services.Semantic.CValue
+{
+    public class Term
+    {
+        private String term;
+        private float score;
+
+
+        public Term()
+        {
+
+        }
+
+        public Term(String pTerm)
+        {
+            term = pTerm;
+            score = -1;
+        }
+
+        public Term(String pTerm, float pScore)
+        {
+            term = pTerm;
+            score = pScore;
+        }
+
+        public String getTerm()
+        {
+            return term;
+        }
+
+        public void setTerm(String term)
+        {
+            this.term = term;
+        }
+
+        public float getScore()
+        {
+            return score;
+        }
+
+        public void setScore(float score)
+        {
+            this.score = score;
+        }
+
+        public override string ToString()
+        {
+            return term + "\t" + score;
+        }
+
+
+        public override bool Equals(object obj)
+        {
+            return Equals(obj as Term);
+        }
+
+        private bool Equals(Term other)
+        {
+            if (other == null) return false;
+            return this.term.Equals(other.term, StringComparison.OrdinalIgnoreCase);
+        }
+
+        public override int GetHashCode()
+        {
+            int hash = 7;
+            hash = 97 * hash + StringComparer.OrdinalIgnoreCase.GetHashCode(this.term);
+            return hash;
+        }
+    }
+}
diff --git a/ZeroLevel/Services/Semantic/CValue/Token.cs b/ZeroLevel/Services/Semantic/CValue/Token.cs
new file mode 100644
index 0000000..8af34e4
--- /dev/null
+++ b/ZeroLevel/Services/Semantic/CValue/Token.cs
@@ -0,0 +1,83 @@
+using System;
+
+namespace ZeroLevel.Services.Semantic.CValue
+{
+    public class Token
+    {
+        private String wordForm;
+        private String posTag;
+        private String chunkerTag;
+        private String lemma;
+        private int pos; //position inside the sentence?
+ + public Token(String pWordForm) + { + wordForm = pWordForm; + } + + public Token(String pWordForm, String pPostag) + { + wordForm = pWordForm; + posTag = pPostag; + } + + public Token(String pWordForm, String pPostag, String pLemma) + { + wordForm = pWordForm; + posTag = pPostag; + lemma = pLemma; + } + + public Token(String pWordForm, String pPostag, String pLemma, String pChunker) + { + wordForm = pWordForm; + posTag = pPostag; + lemma = pLemma; + chunkerTag = pChunker; + } + + public String getWordForm() + { + return wordForm; + } + + public void setWordForm(String wordForm) + { + this.wordForm = wordForm; + } + + public String getPosTag() + { + return posTag; + } + + public void setPosTag(String posTag) + { + this.posTag = posTag; + } + + public override string ToString() + { + return wordForm + "\t" + posTag; + } + + public String getLemma() + { + return lemma; + } + + public void setLemma(String lemma) + { + this.lemma = lemma; + } + + public String getChunkerTag() + { + return chunkerTag; + } + + public void setChunkerTag(String chunkerTag) + { + this.chunkerTag = chunkerTag; + } + } diff --git a/ZeroLevel/Services/Semantic/Fasttext/FTArgs.cs b/ZeroLevel/Services/Semantic/Fasttext/FTArgs.cs index 7fd1ad8..61864a6 100644 --- a/ZeroLevel/Services/Semantic/Fasttext/FTArgs.cs +++ b/ZeroLevel/Services/Semantic/Fasttext/FTArgs.cs @@ -1,8 +1,16 @@ -namespace ZeroLevel.Services.Semantic.Fasttext +using System; +using System.Text; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.Services.Semantic.Fasttext { public class FTArgs + : IBinarySerializable { #region Args + public string input; + public string output; + public double lr; public int lrUpdateRate; public int dim; @@ -160,5 +168,370 @@ " -dsub size of each sub-vector [" + dsub + "]\n"; } #endregion + + public void parseArgs(params string[] args) + { + var command = args[1]; + if (command.Equals("supervised", System.StringComparison.OrdinalIgnoreCase)) + { + model = model_name.sup; + loss = 
loss_name.softmax; + minCount = 1; + minn = 0; + maxn = 0; + lr = 0.1; + } + else if (command.Equals("cbow", System.StringComparison.OrdinalIgnoreCase)) + { + model = model_name.cbow; + } + for (int ai = 2; ai < args.Length; ai += 2) + { + if (args[ai][0] != '-') + { + Log.Warning("Provided argument without a dash! Usage: " + printHelp()); + } + try + { + if (args[ai].Equals("-h", System.StringComparison.OrdinalIgnoreCase)) + { + Log.Warning("Here is the help! Usage: " + printHelp()); + } + else if (args[ai].Equals("-input", System.StringComparison.OrdinalIgnoreCase)) + { + input = args[ai + 1]; + } + else if (args[ai].Equals("-output", System.StringComparison.OrdinalIgnoreCase)) + { + output = args[ai + 1]; + } + else if (args[ai].Equals("-lr", System.StringComparison.OrdinalIgnoreCase)) + { + lr = double.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-lrUpdateRate", System.StringComparison.OrdinalIgnoreCase)) + { + lrUpdateRate = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-dim", System.StringComparison.OrdinalIgnoreCase)) + { + dim = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-ws", System.StringComparison.OrdinalIgnoreCase)) + { + ws = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-epoch", System.StringComparison.OrdinalIgnoreCase)) + { + epoch = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-minCount", System.StringComparison.OrdinalIgnoreCase)) + { + minCount = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-minCountLabel", System.StringComparison.OrdinalIgnoreCase)) + { + minCountLabel = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-neg", System.StringComparison.OrdinalIgnoreCase)) + { + neg = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-wordNgrams", System.StringComparison.OrdinalIgnoreCase)) + { + wordNgrams = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-loss", System.StringComparison.OrdinalIgnoreCase)) + { + if (args[ai + 1].Equals("hs", 
System.StringComparison.OrdinalIgnoreCase)) + { + loss = loss_name.hs; + } + else if (args[ai + 1].Equals("ns", System.StringComparison.OrdinalIgnoreCase)) + { + loss = loss_name.ns; + } + else if (args[ai + 1].Equals("softmax", System.StringComparison.OrdinalIgnoreCase)) + { + loss = loss_name.softmax; + } + else + { + loss = loss_name.ns; + Log.Warning("Unknown loss! Usage: " + printHelp()); + } + } + else if (args[ai].Equals("-bucket", System.StringComparison.OrdinalIgnoreCase)) + { + bucket = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-minn", System.StringComparison.OrdinalIgnoreCase)) + { + minn = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-maxn", System.StringComparison.OrdinalIgnoreCase)) + { + maxn = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-thread", System.StringComparison.OrdinalIgnoreCase)) + { + thread = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-t", System.StringComparison.OrdinalIgnoreCase)) + { + t = double.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-label", System.StringComparison.OrdinalIgnoreCase)) + { + label = args[ai + 1]; + } + else if (args[ai].Equals("-verbose", System.StringComparison.OrdinalIgnoreCase)) + { + verbose = int.Parse(args[ai + 1]); + } + else if (args[ai].Equals("-pretrainedVectors", System.StringComparison.OrdinalIgnoreCase)) + { + pretrainedVectors = args[ai + 1]; + } + else if (args[ai].Equals("-saveOutput", System.StringComparison.OrdinalIgnoreCase)) + { + saveOutput = true; + ai--; + } + else if (args[ai].Equals("-qnorm", System.StringComparison.OrdinalIgnoreCase)) + { + qnorm = true; + ai--; + } + else if (args[ai].Equals("-retrain", System.StringComparison.OrdinalIgnoreCase)) + { + retrain = true; + ai--; + } + else if (args[ai].Equals("-qout", System.StringComparison.OrdinalIgnoreCase)) + { + qout = true; + ai--; + } + else if (args[ai].Equals("-cutoff", System.StringComparison.OrdinalIgnoreCase)) + { + cutoff = ulong.Parse(args[ai + 1]); + } + else if 
(args[ai] == "-dsub") + { + dsub = ulong.Parse(args[ai + 1]); + } + else + { + Log.Warning("Unknown argument: " + args[ai] + "! Usage: " + printHelp()); + } + } + catch (Exception ex) + { + Log.Error(ex, ""); + } + } + if (string.IsNullOrWhiteSpace(input) || string.IsNullOrWhiteSpace(output)) + { + throw new Exception("Empty input or output path.\r\n" + printHelp()); + } + if (wordNgrams <= 1 && maxn == 0) + { + bucket = 0; + } + } + + public void parseArgs(IConfiguration config) + { + if (config.Contains("supervised")) + { + model = model_name.sup; + loss = loss_name.softmax; + minCount = 1; + minn = 0; + maxn = 0; + lr = 0.1; + } + else if (config.Contains("cbow")) + { + model = model_name.cbow; + } + foreach (var key in config.Keys) + { + switch (key) + { + case "input": + case "-input": + input = config.First(key); + break; + case "output": + case "-output": + output = config.First(key); + break; + case "lr": + case "-lr": + lr = config.First(key); + break; + case "": + break; + case "lrUpdateRate": + case "-lrUpdateRate": + lrUpdateRate = config.First(key); + break; + case "dim": + case "-dim": + dim = config.First(key); + break; + case "ws": + case "-ws": + ws = config.First(key); + break; + case "epoch": + case "-epoch": + epoch = config.First(key); + break; + case "minCount": + case "-minCount": + minCount = config.First(key); + break; + case "minCountLabel": + case "-minCountLabel": + minCountLabel = config.First(key); + break; + case "neg": + case "-neg": + neg = config.First(key); + break; + case "wordNgrams": + case "-wordNgrams": + wordNgrams = config.First(key); + break; + case "loss": + case "-loss": + switch (config.First(key)) + { + case "hs": loss = loss_name.hs; break; + case "ns": loss = loss_name.ns; break; + case "softmax": loss = loss_name.softmax; break; + default: Log.Warning("Unknown loss! 
Usage: " + printHelp()); break; + } + break; + case "bucket": + case "-bucket": + bucket = config.First(key); + break; + case "minn": + case "-minn": + minn = config.First(key); + break; + case "maxn": + case "-maxn": + maxn = config.First(key); + break; + case "thread": + case "-thread": + thread = config.First(key); + break; + case "t": + case "-t": + t = config.First(key); + break; + case "label": + case "-label": + label = config.First(key); + break; + case "verbose": + case "-verbose": + verbose = config.First(key); + break; + case "pretrainedVectors": + case "-pretrainedVectors": + pretrainedVectors = config.First(key); + break; + case "saveOutput": + case "-saveOutput": + saveOutput = true; + break; + case "qnorm": + case "-qnorm": + qnorm = true; + break; + case "retrain": + case "-retrain": + retrain = true; + break; + case "qout": + qout = true; + break; + case "cutoff": + cutoff = config.First(key); + break; + case "dsub": + dsub = config.First(key); + break; + } + } + if (string.IsNullOrWhiteSpace(input) || string.IsNullOrWhiteSpace(output)) + { + throw new Exception("Empty input or output path.\r\n" + printHelp()); + } + if (wordNgrams <= 1 && maxn == 0) + { + bucket = 0; + } + } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteInt32(dim); + writer.WriteInt32(ws); + writer.WriteInt32(epoch); + writer.WriteInt32(minCount); + writer.WriteInt32(neg); + writer.WriteInt32(wordNgrams); + writer.WriteInt32((int)loss); + writer.WriteInt32((int)model); + writer.WriteInt32(bucket); + writer.WriteInt32(minn); + writer.WriteInt32(maxn); + writer.WriteInt32(lrUpdateRate); + writer.WriteDouble(t); + } + + public void Deserialize(IBinaryReader reader) + { + dim = reader.ReadInt32(); + ws = reader.ReadInt32(); + epoch = reader.ReadInt32(); + minCount = reader.ReadInt32(); + neg = reader.ReadInt32(); + wordNgrams = reader.ReadInt32(); + loss = (loss_name)reader.ReadInt32(); + model = (model_name)reader.ReadInt32(); + bucket = reader.ReadInt32(); + minn 
= reader.ReadInt32(); + maxn = reader.ReadInt32(); + lrUpdateRate = reader.ReadInt32(); + t = reader.ReadDouble(); + } + + public string dump() + { + var dump = new StringBuilder(); + dump.AppendLine($"dim {dim}"); + dump.AppendLine($"ws {ws}"); + dump.AppendLine($"epoch {epoch}"); + dump.AppendLine($"minCount {minCount}"); + dump.AppendLine($"neg {neg}"); + dump.AppendLine($"wordNgrams {wordNgrams}"); + dump.AppendLine($"loss {lossToString(loss)}"); + dump.AppendLine($"model {modelToString(model)}"); + dump.AppendLine($"bucket {bucket}"); + dump.AppendLine($"minn {minn}"); + dump.AppendLine($"maxn {maxn}"); + dump.AppendLine($"lrUpdateRate {lrUpdateRate}"); + dump.AppendLine($"t {t}"); + return dump.ToString(); + } } } diff --git a/ZeroLevel/Services/Semantic/Fasttext/FTDictionary.cs b/ZeroLevel/Services/Semantic/Fasttext/FTDictionary.cs new file mode 100644 index 0000000..49d1893 --- /dev/null +++ b/ZeroLevel/Services/Semantic/Fasttext/FTDictionary.cs @@ -0,0 +1,487 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace ZeroLevel.Services.Semantic.Fasttext +{ + internal class FTEntry + { + public string word; + public long count; + public entry_type type; + public List subwords; + } + + internal class FTDictionary + { + const int MAX_VOCAB_SIZE = 30000000; + const int MAX_LINE_SIZE = 1024; + const string EOS = ""; + const string BOW = "<"; + const string EOW = ">"; + + private readonly FTArgs _args; + private List word2int; + private List words; + float[] pdiscard; + int size; + int nwords; + int nlabels; + long ntokens; + long pruneidx_size; + Dictionary pruneidx; + + public FTDictionary(FTArgs args) + { + _args = args; + word2int = new List(); + + size = 0; + nwords = 0; + nlabels = 0; + ntokens = 0; + pruneidx_size = -1; + } + + public FTDictionary(FTArgs args, Stream stream) + { + _args = args; + size = 0; + nwords = 0; + nlabels = 0; + ntokens = 0; + pruneidx_size = -1; + load(stream); + } + + + public int 
find(string w) => find(w, hash(w));
+
+        public int find(string w, uint h)
+        {
+            int word2intsize = word2int.Count;
+            int id = (int)(h % word2intsize);
+            while (word2int[id] != -1 && words[word2int[id]].word != w)
+            {
+                id = (id + 1) % word2intsize;
+            }
+            return id;
+        }
+
+        public void add(string w)
+        {
+            int h = find(w);
+            ntokens++;
+            if (word2int[h] == -1)
+            {
+                FTEntry e = new FTEntry
+                {
+                    word = w,
+                    count = 1,
+                    type = getType(w)
+                };
+                words.Add(e);
+                word2int[h] = size++;
+            }
+            else
+            {
+                var e = words[word2int[h]];
+                e.count++;
+            }
+        }
+
+        public List<int> getSubwords(int id)
+        {
+            if (id < 0 || id >= nwords)
+            {
+                throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {nwords}");
+            }
+            return words[id].subwords;
+        }
+
+        public List<int> getSubwords(string word)
+        {
+            int i = getId(word);
+            if (i >= 0)
+            {
+                return getSubwords(i);
+            }
+            var ngrams = new List<int>();
+            if (word != EOS)
+            {
+                computeSubwords(BOW + word + EOW, ngrams);
+            }
+            return ngrams;
+        }
+
+        public void getSubwords(string word,
+            List<int> ngrams,
+            List<string> substrings)
+        {
+            int i = getId(word);
+            ngrams.Clear();
+            substrings.Clear();
+            if (i >= 0)
+            {
+                ngrams.Add(i);
+                substrings.Add(words[i].word);
+            }
+            if (word != EOS)
+            {
+                computeSubwords(BOW + word + EOW, ngrams, substrings);
+            }
+        }
+
+        public bool discard(int id, float rand)
+        {
+            if (id < 0 || id >= nwords)
+            {
+                throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {nwords}");
+            }
+            if (_args.model == model_name.sup) return false;
+            return rand > pdiscard[id];
+        }
+
+        public uint hash(string str)
+        {
+            uint h = 2166136261;
+            for (var i = 0; i < str.Length; i++)
+            {
+                h = h ^ str[i];
+                h = h * 16777619;
+            }
+            return h;
+        }
+
+        public int getId(string w, uint h)
+        {
+            int id = find(w, h);
+            return word2int[id];
+        }
+
+        public int getId(string w)
+        {
+            int h = find(w);
+            return word2int[h];
+        }
+
+        public entry_type getType(int id)
+        {
+            if (id < 0 || id >= size)
+            {
+                throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {size}");
+            }
+            return words[id].type;
+        }
+
+        public entry_type getType(string w)
+        {
+            return (w.IndexOf(_args.label) == 0) ? entry_type.label : entry_type.word;
+        }
+
+        public string getWord(int id)
+        {
+            if (id < 0 || id >= size)
+            {
+                throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {size}");
+            }
+            return words[id].word;
+        }
+
+        public void computeSubwords(string word, List<int> ngrams, List<string> substrings)
+        {
+            for (var i = 0; i < word.Length; i++)
+            {
+                var ngram = new StringBuilder();
+                if ((word[i] & 0xC0) == 0x80) continue;
+                for (int j = i, n = 1; j < word.Length && n <= _args.maxn; n++)
+                {
+                    ngram.Append(word[j++]);
+                    while (j < word.Length && (word[j] & 0xC0) == 0x80)
+                    {
+                        ngram.Append(word[j++]);
+                    }
+                    if (n >= _args.minn && !(n == 1 && (i == 0 || j == word.Length)))
+                    {
+                        var sw = ngram.ToString();
+                        var h = hash(sw) % _args.bucket;
+                        ngrams.Add((int)(nwords + h));
+                        substrings.Add(sw);
+                    }
+                }
+            }
+        }
+
+        public void computeSubwords(string word, List<int> ngrams)
+        {
+            for (var i = 0; i < word.Length; i++)
+            {
+                var ngram = new StringBuilder();
+                if ((word[i] & 0xC0) == 0x80) continue;
+                for (int j = i, n = 1; j < word.Length && n <= _args.maxn; n++)
+                {
+                    ngram.Append(word[j++]);
+                    while (j < word.Length && (word[j] & 0xC0) == 0x80)
+                    {
+                        ngram.Append(word[j++]);
+                    }
+                    if (n >= _args.minn && !(n == 1 && (i == 0 || j == word.Length)))
+                    {
+                        var sw = ngram.ToString();
+                        var h = (int)(hash(sw) % _args.bucket);
+                        pushHash(ngrams, h);
+                    }
+                }
+            }
+        }
+
+        public void pushHash(List<int> hashes, int id)
+        {
+            if (pruneidx_size == 0 || id < 0) return;
+            if (pruneidx_size > 0)
+            {
+                if (pruneidx.ContainsKey(id))
+                {
+                    id = pruneidx[id];
+                }
+                else
+                {
+                    return;
+                }
+            }
+            hashes.Add(nwords + id);
+        }
+
+        public void reset(Stream stream)
+        {
+            if (stream.Position > 0)
+            {
+                stream.Position = 0;
+            }
+        }
+
+        public string getLabel(int lid)
+        {
+            if (lid < 0 || lid >= nlabels)
+            {
+                throw new Exception($"Label id is out of range [0, 
{nlabels}]"); + } + return words[lid + nwords].word; + } + + public void initNgrams() + { + for (var i = 0; i < size; i++) + { + string word = BOW + words[i].word + EOW; + words[i].subwords.Clear(); + words[i].subwords.Add(i); + if (words[i].word != EOS) + { + computeSubwords(word, words[i].subwords); + } + } + } + + public bool readWord(Stream stream, StringBuilder word) + { + int c; + std::streambuf & sb = *in.rdbuf(); + word = null; + while ((c = sb.sbumpc()) != EOF) + { + if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' || + c == '\f' || c == '\0') + { + if (word.empty()) + { + if (c == '\n') + { + word += EOS; + return true; + } + continue; + } + else + { + if (c == '\n') + sb.sungetc(); + return true; + } + } + word.push_back(c); + } + in.get(); + return !word.empty(); + } + + public void readFromFile(Stream stream) + { + string word; + long minThreshold = 1; + while (readWord(stream, out word)) + { + add(word); + if (ntokens % 1000000 == 0 && _args.verbose > 1) + { + // std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush; + } + if (size > 0.75 * MAX_VOCAB_SIZE) + { + minThreshold++; + threshold(minThreshold, minThreshold); + } + } + threshold(_args.minCount, _args.minCountLabel); + initTableDiscard(); + initNgrams(); + //if (args_->verbose > 0) + //{ + // std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl; + // std::cerr << "Number of words: " << nwords_ << std::endl; + // std::cerr << "Number of labels: " << nlabels_ << std::endl; + //} + if (size == 0) + { + throw std::invalid_argument( + "Empty vocabulary. 
Try a smaller -minCount value."); + } + } + + public void threshold(long t, long tl) + { + sort(words_.begin(), words_.end(), [](const entry&e1, const entry&e2) { + if (e1.type != e2.type) return e1.type < e2.type; + return e1.count > e2.count; + }); + words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry&e) { + return (e.type == entry_type::word && e.count < t) || + (e.type == entry_type::label && e.count < tl); + }), words_.end()); + words_.shrink_to_fit(); + size_ = 0; + nwords_ = 0; + nlabels_ = 0; + std::fill(word2int_.begin(), word2int_.end(), -1); + for (auto it = words_.begin(); it != words_.end(); ++it) + { + int32_t h = find(it->word); + word2int_[h] = size_++; + if (it->type == entry_type::word) nwords_++; + if (it->type == entry_type::label) nlabels_++; + } + } + + public void initTableDiscard() + { + pdiscard.resize(size); + for (var i = 0; i < size; i++) + { + var f = ((float)words[i].count) / (float)(ntokens); + pdiscard[i] = (float)Math.Sqrt(_args.t / f) + _args.t / f; + } + } + + public List getCounts(entry_type type) + { + var counts = new List(); + foreach (var w in words) + { + if (w.type == type) counts.Add(w.count); + } + return counts; + } + + public void addWordNgrams(List line, List hashes, int n) + { + for (var i = 0; i < hashes.Count; i++) + { + var h = hashes[i]; + for (var j = i + 1; j < hashes.Count && j < i + n; j++) + { + h = h * 116049371 + hashes[j]; + pushHash(line, h % _args.bucket); + } + } + } + + public void addSubwords(List line, string token, int wid) + { + if (wid < 0) + { // out of vocab + if (token != EOS) + { + computeSubwords(BOW + token + EOW, line); + } + } + else + { + if (_args.maxn <= 0) + { // in vocab w/o subwords + line.Add(wid); + } + else + { // in vocab w/ subwords + var ngrams = getSubwords(wid); + line.AddRange(ngrams); + } + } + } + + public int getLine(Stream stream, List words, Random rng) + { + std::uniform_real_distribution<> uniform(0, 1); + string token; + int ntokens = 0; + + 
reset(in); + words.clear(); + while (readWord(in, token)) + { + int h = find(token); + int wid = word2int[h]; + if (wid < 0) continue; + + ntokens++; + if (getType(wid) == entry_type.word && !discard(wid, uniform(rng))) + { + words.Add(wid); + } + if (ntokens > MAX_LINE_SIZE || token == EOS) break; + } + return ntokens; + } + + public int getLine(Stream stream, List words, List labels) + { + std::vector word_hashes; + string token; + int ntokens = 0; + + reset(in); + words.clear(); + labels.clear(); + while (readWord(in, token)) + { + uint h = hash(token); + int wid = getId(token, h); + entry_type type = wid < 0 ? getType(token) : getType(wid); + + ntokens++; + if (type == entry_type.word) + { + addSubwords(words, token, wid); + word_hashes.push_back(h); + } + else if (type == entry_type.label && wid >= 0) + { + labels.push_back(wid - nwords); + } + if (token == EOS) break; + } + addWordNgrams(words, word_hashes, args_->wordNgrams); + return ntokens; + } + } +} diff --git a/ZeroLevel/Services/Semantic/Fasttext/enums.cs b/ZeroLevel/Services/Semantic/Fasttext/enums.cs index b081a14..bf09d6c 100644 --- a/ZeroLevel/Services/Semantic/Fasttext/enums.cs +++ b/ZeroLevel/Services/Semantic/Fasttext/enums.cs @@ -2,4 +2,5 @@ { public enum model_name : int { cbow = 1, sg, sup }; public enum loss_name : int { hs = 1, ns, softmax }; + public enum entry_type : byte { word=0, label=1}; }