From 5b80521c40cb3522b75064cc24c0575ccea2677b Mon Sep 17 00:00:00 2001 From: gleb Date: Mon, 20 Mar 2023 08:59:40 +0400 Subject: [PATCH 01/10] draft wordpiece --- LICENSE | 2 +- MANIFEST.in | 5 +- README.md | 88 +++--- benchmark.md => benchmark_bpe.md | 12 +- benchmark_wordpiece.md | 38 +++ setup.py | 2 +- tests/speed_test/Dockerfile | 18 +- tests/speed_test/README.md | 37 ++- tests/speed_test/{speed_test.py => bpe.py} | 8 +- tests/speed_test/wordpiece.py | 251 ++++++++++++++++++ tests/unit_tests/{ => bpe}/README.md | 5 +- tests/unit_tests/{ => bpe}/stress_test.cpp | 0 tests/unit_tests/{ => bpe}/stress_test.h | 2 +- tests/unit_tests/{ => bpe}/test_cli.py | 0 tests/unit_tests/{ => bpe}/test_manual.py | 0 tests/unit_tests/{ => bpe}/test_python_api.py | 0 tests/unit_tests/{ => bpe}/test_stress.py | 0 .../unit_tests/{ => bpe}/utils_for_testing.py | 0 tests/unit_tests/wordpiece/README.md | 6 + tests/unit_tests/wordpiece/test_cli.py | 0 tests/unit_tests/wordpiece/test_manual.py | 0 tests/unit_tests/wordpiece/test_python_api.py | 0 youtokentome/cpp/bpe.cpp | 90 +------ youtokentome/cpp/bpe.h | 6 +- .../third_party/{ => flat_hash_map}/LICENSE | 0 .../{ => flat_hash_map}/flat_hash_map.h | 0 .../cpp/third_party/thread_pool/LICENSE | 21 ++ .../cpp/third_party/thread_pool/thread_pool.h | 91 +++++++ youtokentome/cpp/utf8.cpp | 43 ++- youtokentome/cpp/utf8.h | 71 ++++- youtokentome/cpp/utils.cpp | 28 +- youtokentome/cpp/utils.h | 101 ++++++- youtokentome/cpp/wordpiece.cpp | 234 ++++++++++++++++ youtokentome/cpp/wordpiece.h | 17 ++ 34 files changed, 978 insertions(+), 198 deletions(-) rename benchmark.md => benchmark_bpe.md (92%) create mode 100644 benchmark_wordpiece.md rename tests/speed_test/{speed_test.py => bpe.py} (98%) create mode 100644 tests/speed_test/wordpiece.py rename tests/unit_tests/{ => bpe}/README.md (64%) rename tests/unit_tests/{ => bpe}/stress_test.cpp (100%) rename tests/unit_tests/{ => bpe}/stress_test.h (91%) rename tests/unit_tests/{ => bpe}/test_cli.py (100%) rename tests/unit_tests/{ => bpe}/test_manual.py (100%) rename tests/unit_tests/{ => bpe}/test_python_api.py (100%) rename tests/unit_tests/{ => bpe}/test_stress.py (100%) rename tests/unit_tests/{ => bpe}/utils_for_testing.py (100%) create mode 100644 tests/unit_tests/wordpiece/README.md create mode 100644 tests/unit_tests/wordpiece/test_cli.py create mode 100644 tests/unit_tests/wordpiece/test_manual.py create mode 100644 tests/unit_tests/wordpiece/test_python_api.py rename youtokentome/cpp/third_party/{ => flat_hash_map}/LICENSE (100%) rename youtokentome/cpp/third_party/{ => flat_hash_map}/flat_hash_map.h (100%) create mode 100644 youtokentome/cpp/third_party/thread_pool/LICENSE create mode 100644 youtokentome/cpp/third_party/thread_pool/thread_pool.h create mode 100644 youtokentome/cpp/wordpiece.cpp create mode 100644 youtokentome/cpp/wordpiece.h diff --git a/LICENSE b/LICENSE index 52a300b..40b88d8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2019 VK.com +Copyright (c) 2019-2023 VK.com Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in index 4ce0eae..92d8a80 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,12 +2,9 @@ include youtokentome/cpp/utils.h include youtokentome/cpp/bpe.h include youtokentome/cpp/utf8.h include youtokentome/cpp/yttm.pyx -include youtokentome/cpp/third_party/flat_hash_map.h +include 
youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h
 include youtokentome/cpp/third_party/LICENSE
 include LICENSE
 include README.md
 include requirements.txt
 include yttm_cli.py
-
-
-
diff --git a/README.md b/README.md
index 00b123e..a4ec64f 100644
--- a/README.md
+++ b/README.md
@@ -6,20 +6,21 @@
 
 # YouTokenToMe
 
-YouTokenToMe is an unsupervised text tokenizer focused on computational efficiency. It currently implements fast Byte Pair Encoding (BPE) [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)].
-Our implementation is much faster in training and tokenization than [Hugging Face](https://github.com/huggingface/tokenizers), [fastBPE](https://github.com/glample/fastBPE)
- and [SentencePiece](https://github.com/google/sentencepiece). In some test cases, it is 60 times faster.
- Check out our [benchmark](benchmark.md) results.
+YouTokenToMe is an unsupervised text tokenizer focused on computational efficiency. It currently contains the fastest implementations of:
+- Byte Pair Encoding (BPE) [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)], [benchmark results](benchmark.md);
+- WordPiece [[Song et al.](https://arxiv.org/abs/2012.15524)], [benchmark results](benchmark.md).
 
 Key advantages:
 
 * Multithreading for training and tokenization
-* The algorithm has `O(N)` complexity, where `N` is the length of training data
 * Highly efficient implementation in C++
 * Python wrapper and command-line interface
 
-Extra features:
-* BPE-dropout (as described in [Provilkov et al, 2019](https://arxiv.org/abs/1910.13267))
+## BPE implementation
+
+Algorithm properties:
+* Time complexity is `O(N)`, where `N` is the length of the training data
+* Supports BPE-dropout (as described in [Provilkov et al, 2019](https://arxiv.org/abs/1910.13267))
 
 As well as in the algorithm from the original paper, ours does not consider tokens that cross word boundaries. Just like in [SentencePiece](https://github.com/google/sentencepiece), all space symbols were replaced by meta symbol "▁" (U+2581). It allows sequences of tokens to be converted back to text and for word boundaries to be restored.
 
@@ -28,15 +29,21 @@
 For example, the phrase ```Blazingly fast tokenization!``` can be tokenized into
 
 `['▁Bl', 'az', 'ingly', '▁fast', '▁token', 'ization', '!']`
 
+## WordPiece implementation
+
+Algorithm properties:
+* Currently supports tokenization only, not training
+* Time complexity is `O(NM)`, where `N` is the length of the tokenized data and `M` is the maximum word length in the vocabulary
+
 ## Installation
 
 ```bash
 pip install youtokentome
 ```
+
 ## Python interface
 
-### Example
-Let's start with a self-contained example.
+### BPE Example
 
 ```python
 import random
@@ -68,10 +75,31 @@
 print(bpe.encode([test_text], output_type=yttm.OutputType.ID))
 print(bpe.encode([test_text], output_type=yttm.OutputType.SUBWORD))
 ```
 
+### WordPiece Example
+
+TODO
+
+### BPE Methods
+Class `youtokentome.BPE` has the following methods:
+
+#### constructor
+
+```python
+youtokentome.BPE(model, n_threads=-1)
+```
+
+Class constructor. Loads the trained model.
+
+* `model`: string, path to the trained model
+* `n_threads`: int, number of parallel threads used to run.
+  If equal to -1, then the maximum number of threads available will be used.
+
 &nbsp;
 
-### Training model
+
+#### train
+
 ```python
-youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3)
+train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3)
 ```
 
 Trains BPE model and saves to file.
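The `O(NM)` WordPiece bound above comes from greedy longest-match lookup: at each position the longest candidate subword is tried first, then shrunk from the right. A minimal pure-Python sketch of that procedure for a single word, mirroring what `youtokentome/cpp/wordpiece.cpp` implements later in this patch (`vocab`, `unk_id`, and `max_token_len` are illustrative inputs, not library API):

```python
def wordpiece_encode_word(word, vocab, unk_id, max_token_len):
    """Greedy longest-match WordPiece for one word: O(len(word) * max_token_len)."""
    ids, start = [], 0
    while start < len(word):
        end = min(len(word), start + max_token_len)
        match = None
        while end > start:  # try the longest piece first, shrink from the right
            # pieces that continue a word carry the "##" prefix in the vocab
            piece = word[start:end] if start == 0 else "##" + word[start:end]
            if piece in vocab:
                match = vocab[piece]
                break
            end -= 1
        if match is None:
            return [unk_id]  # no piece matched: the whole word becomes UNK
        ids.append(match)
        start = end
    return ids

# e.g. wordpiece_encode_word("tokenization", {"token": 0, "##ization": 1}, 100, 9) == [0, 1]
```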
@@ -92,22 +120,6 @@ Trains BPE model and saves to file.   -### Model loading - -```python -youtokentome.BPE(model, n_threads=-1) -``` - -Class constructor. Loads the trained model. - -* `model`: string, path to the trained model -* `n_threads`: int, number of parallel threads used to run. - If equal to -1, then the maximum number of threads available will be used. - -  - -### Methods -Class `youtokentome.BPE` has the following methods: #### encode ```python encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, reverse=False, dropout_prob=0) @@ -185,16 +197,12 @@ Convert each id to subword and concatenate with space symbol. **Returns:** List of strings. - -## Command line interface -### Example +### WordPiece methods -```bash -$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 -$ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA -``` +TODO +## Command line interface ### Supported commands @@ -241,7 +249,7 @@ Options: Apply BPE encoding for a corpus of sentences. Use `stdin` for input and `stdout` for output. By default, encoding works in parallel using `n_threads` threads. Number of threads is limited by -8 (see [benchmark](benchmark.md#number-of-threads)). +8 (see [benchmark](benchmark_bpe.md#number-of-threads)). With the `--stream` option, `--n_threads` will be ignored and all sentences will be processed one by one. Each sentence will be tokenized and written to the `stdout` before the next sentence is read. @@ -296,9 +304,11 @@ Options: --help Show this message and exit. ``` +### Examples +TODO: wordpiece - - - - +```bash +$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 +$ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA +``` \ No newline at end of file diff --git a/benchmark.md b/benchmark_bpe.md similarity index 92% rename from benchmark.md rename to benchmark_bpe.md index 36c43b2..bc1805f 100644 --- a/benchmark.md +++ b/benchmark_bpe.md @@ -1,7 +1,11 @@ -## Speed tests +## BPE Speed tests -`YouTokenToMe` will be compared with [Hugging Face](https://github.com/huggingface/tokenizers), [SentencePiece](https://github.com/google/sentencepiece/) - and [fastBPE](https://github.com/glample/fastBPE). These three algorithms are considered to be fast. +`YouTokenToMe` will be compared with: +* [Hugging Face](https://github.com/huggingface/tokenizers) +* [SentencePiece](https://github.com/google/sentencepiece/) +* [fastBPE](https://github.com/glample/fastBPE) + +These algorithms are considered to be fast. Data from [Wikipedia](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) was used to evaluate algorithm speed. In a similar way to `enwik8` and `enwik9`, the experiments were run on first `10^8` and `10^9` bytes of datasets for English, Russian, Chinese and Japanese. @@ -11,7 +15,7 @@ In this benchmark, `YouTokenToMe` used 4 threads for training and tokenization. doesn't support multithreading for **BPE** at all. `fastBPE` doesn't support multithreading for training. For tokenization, it also used 4 threads. -Source code for benchmark can be found [here](tests/speed_test/speed_test.py). +Source code for benchmark can be found [here](tests/speed_test/bpe.py). The results of the experiments are below. The time is measured in seconds. 
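The measurement itself is the same for every library: read the corpus into memory, run one end-to-end tokenization pass, and record wall-clock seconds. A sketch of that pattern, with the `tokenize` callable and the file path as placeholder assumptions:

```python
from time import time

def measure_tokenization(tokenize, text_file):
    with open(text_file, "r") as f:
        text = f.read()          # file I/O stays outside the timed region
    start = time()
    ids = tokenize(text)         # one end-to-end pass over the corpus
    elapsed = time() - start     # seconds, as reported in the tables below
    assert len(ids) > 0
    return elapsed
```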
All experiments were run on the following machine:
 
diff --git a/benchmark_wordpiece.md b/benchmark_wordpiece.md
new file mode 100644
index 0000000..aaf4e27
--- /dev/null
+++ b/benchmark_wordpiece.md
@@ -0,0 +1,38 @@
+## WordPiece Speed tests
+
+`YouTokenToMe` will be compared with:
+* [Hugging Face](https://github.com/huggingface/tokenizers)
+* [Keras](https://github.com/keras-team/keras-nlp)
+* [Tensorflow](https://github.com/tensorflow/text)
+* [Torch](https://github.com/pytorch/text)
+
+These algorithms are considered to be fast.
+
+Data from [Wikipedia](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) was used to evaluate algorithm speed. In a similar way to `enwik8` and `enwik9`, the experiments were run on the first `10^8` and `10^9` bytes of the datasets for English, Russian, Chinese and Japanese.
+
+Used vocabulary: [bert-base-cased](https://huggingface.co/bert-base-cased).
+
+In this benchmark, `YouTokenToMe` used 4 threads for tokenization.
+
+Source code for benchmark can be found [here](tests/speed_test/wordpiece.py).
+The results of the experiments are below. The time is measured in seconds.
+
+All experiments were run on the following machine: TODO
+
+### Tokenization 100MB
+TODO: TABLE
+
+### Tokenization 1GB
+TODO: TABLE
+
+`YouTokenToMe` performed really well in this benchmark. This is especially noticeable for languages with large alphabets.
+
+## Number of threads
+
+The table below shows the dependence of performance on the number of threads for `YouTokenToMe`.
+
+### Tokenization 1GB
+TODO: TABLE
+
+
+TODO: CONCLUSION ON THREADS
diff --git a/setup.py b/setup.py
index 867fc0f..b603d23 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
             "youtokentome/cpp/utils.cpp",
             "youtokentome/cpp/utf8.cpp",
         ],
-        extra_compile_args=["-std=c++11", "-pthread", "-O3"],
+        extra_compile_args=["-std=c++17", "-pthread", "-O3"],
         language="c++",
     )
 ]
diff --git a/tests/speed_test/Dockerfile b/tests/speed_test/Dockerfile
index 2dc36fa..24b385d 100644
--- a/tests/speed_test/Dockerfile
+++ b/tests/speed_test/Dockerfile
@@ -8,8 +8,11 @@
 RUN apt-get update && apt-get install -y --no-install-recommends \
     cmake \
     make \
     g++ \
-    wget && \
-    pip3 install tabulate youtokentome tokenizers
+    wget \
+    bzip2 \
+    perl && \
+    pip3 install -r requirements.txt && \
+    pip3 install youtokentome
 
 WORKDIR /repos
 
@@ -26,8 +29,13 @@
 RUN g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
 
 WORKDIR /workspace
 
-COPY ./speed_test.py ./speed_test.py
 RUN cp /repos/fastBPE/fast /workspace/fastBPE
+RUN wget -O bert-base-cased.txt https://huggingface.co/bert-base-cased/resolve/main/vocab.txt
 
-# CMD ["python", "speed_test.py", "--langs", "en", "ru", "zh", "ja", "--corpus_size", "100", "--vocab_size", "30000"]
-CMD ["python", "speed_test.py", "--langs", "ru", "--corpus_size", "10", "--vocab_size", "30000"]
+COPY ./bpe.py ./bpe.py
+COPY ./wordpiece.py ./wordpiece.py
+
+# use comma to separate langs, e.g.: "--langs", "en", "ru", "zh", "ja"
+CMD ["python", "bpe.py", "--langs", "ru", "--corpus_size", "10", "--vocab_size", "30000"]
+
+CMD ["python", "wordpiece.py", "--langs", "ru", "--corpus_size", "10", "--vocab", "bert-base-cased.txt"]
\ No newline at end of file
diff --git a/tests/speed_test/README.md b/tests/speed_test/README.md
index 3283c13..9460755 100644
--- a/tests/speed_test/README.md
+++ b/tests/speed_test/README.md
@@ -1,25 +1,38 @@
 # Running benchmark
 
-* Install [YouTokenToMe](https://github.com/vkcom/youtokentome)
-* Install [SentencePiece](https://github.com/google/sentencepiece)
-* 
Install [Hugging Face Tokenizer](https://github.com/huggingface/tokenizers) -* Compile [fastBPE](https://github.com/glample/fastBPE) and specify path to binary file in variable - `PATH_TO_FASTBPE` in `speed_test.py` -* `python speed_test.py` - - **Warning!** This test requires about **20 GBs** of free space on your disk and can take **about one hour** for running. +**Warning!** This test requires about **20 GBs** of free space on your disk and can take **about one hour** for running. It uses Wikipedia monolingual corpora for training and tokenization. [Here](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) you can find more details about the data. - -## Docker -Alternatively benchmark can be run using Docker. +## Recommended approach + +Benchmark can be run using Docker. Substitute `PATH_TO_DOWNLOADED_DATA` with absolute path to the directory where wiki dumps will be downloaded. -``` +```bash cd tests/speed_test docker build -t yttm/speed_test . docker run --rm -v PATH_TO_DOWNLOADED_DATA:/workspace/data -it yttm/speed_test:latest ``` + +## Alternative approach + +## BPE benchmark + +* Install [YouTokenToMe](https://github.com/vkcom/youtokentome) +* Install [Hugging Face Tokenizer](https://github.com/huggingface/tokenizers) +* Install [SentencePiece](https://github.com/google/sentencepiece) +* Compile [fastBPE](https://github.com/glample/fastBPE) and specify path to binary file in variable + `PATH_TO_FASTBPE` in `bpe.py` +* `python bpe.py` + +## WordPiece benchmark + +* Install [YouTokenToMe](https://github.com/vkcom/youtokentome) +* Install [Hugging Face Tokenizer](https://github.com/huggingface/tokenizers) +* Install [Keras](https://github.com/keras-team/keras-nlp) +* Install [Tensorflow](https://github.com/tensorflow/text) +* Install [Torch](https://github.com/pytorch/text) +* `python wordpiece.py` diff --git a/tests/speed_test/speed_test.py b/tests/speed_test/bpe.py similarity index 98% rename from tests/speed_test/speed_test.py rename to tests/speed_test/bpe.py index 56e4d20..68adde8 100644 --- a/tests/speed_test/speed_test.py +++ b/tests/speed_test/bpe.py @@ -15,12 +15,12 @@ YOU_TOKEN_TO_ME = "YouTokenToMe" SENTENCE_PIECE = "SentencePiece" FAST_BPE = "fastBPE" -HUGGING_FACE_BPE = "Hugging_Face_BPE" +HUGGING_FACE= "Hugging_Face" PATH_TO_FASTBPE = "./fastBPE" -class HuggingfaceInterface: +class HuggingFaceInterface: def train_from_file(self, train_file, vocab_size, model_file, _): tokenizer = HuggingFaceBPETokenizer(HuggingFaceBPEModel(unk_token="[UNK]")) trainer = HuggingFaceBPETrainer(special_tokens=["[UNK]", "[PAD]"], vocab_size=vocab_size) @@ -90,8 +90,8 @@ def get_bpe(impl_name): return SentencePieceInterface() if impl_name == FAST_BPE: return FastBPEInterface() - if impl_name == HUGGING_FACE_BPE: - return HuggingfaceInterface() + if impl_name == HUGGING_FACE: + return HuggingFaceInterface() assert False diff --git a/tests/speed_test/wordpiece.py b/tests/speed_test/wordpiece.py new file mode 100644 index 0000000..0502cfd --- /dev/null +++ b/tests/speed_test/wordpiece.py @@ -0,0 +1,251 @@ +import argparse +import os +from pathlib import Path +from time import time + +import keras_nlp +import tensorflow +from tabulate import tabulate +from tensorflow_text import BertTokenizer as TensorflowBertTokenizer +from tokenizers import BertWordPieceTokenizer as HuggingFaceBertTokenizer +from torchtext.transforms import BERTTokenizer as TorchBertTokenizer + + +YOU_TOKEN_TO_ME = "YouTokenToMe" +HUGGING_FACE = 'Hugging_Face' +KERAS = 'Keras' +TENSORFLOW = 'Tensorflow' +TORCH 
= 'Torch' + +ALGORITHMS = [YOU_TOKEN_TO_ME, HUGGING_FACE, KERAS, TENSORFLOW, TORCH] +LOWER_CASE = False + + +def collect_to_file(out_file, ids): + if out_file is not None: + with open(out_file, 'w') as f: + for i in ids: + f.write(f'{i} ') + +def run_tensorflow(text_file, vocab_file, n_threads, out_file): + text = "" + with open(text_file, 'r') as f: + text = f.read() + vocab_list = [] + with open(vocab_file, 'r') as f: + for word in f: + vocab_list.append(word) + lookup_table = tensorflow.lookup.StaticVocabularyTable( + tensorflow.lookup.KeyValueTensorInitializer( + keys=vocab_list, + key_dtype=tensorflow.string, + values=tensorflow.range( + tensorflow.size(vocab_list, out_type=tensorflow.int64), dtype=tensorflow.int64), + value_dtype=tensorflow.int64 + ), + num_oov_buckets=1 + ) + tokenizer = TensorflowBertTokenizer(lookup_table, token_out_type=tensorflow.int64, lower_case=LOWER_CASE) + ids = tokenizer.tokenize(text).numpy().tolist() + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_hugging_face(text_file, vocab_file, n_threads, out_file): + with open(text_file, 'r') as f: + text = f.read() + tokenizer = HuggingFaceBertTokenizer(vocab_file, lowercase=LOWER_CASE) + ids = tokenizer.encode(text).ids + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_torch(text_file, vocab_file, n_threads, out_file): + with open(text_file, 'r') as f: + text = f.read() + tokenizer = TorchBertTokenizer(vocab_file, do_lower_case=LOWER_CASE) + ids = tokenizer(text) + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_keras(text_file, vocab_file, n_threads, out_file): + with open(text_file, 'r') as f: + text = f.read() + tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(vocabulary=vocab_file, lowercase=LOWER_CASE) + ids = tokenizer.tokenize(text).numpy().tolist() + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_you_token_to_me(text_file, vocab_file, n_threads, out_file): + assert(LOWER_CASE == False) + out_file = out_file if out_file is not None else "" + rc = 0 # TODO + assert rc == 0 + return rc + + +def get_wordpiece(impl_name): + if impl_name == YOU_TOKEN_TO_ME: + return run_you_token_to_me + elif impl_name == HUGGING_FACE: + return run_hugging_face + elif impl_name == KERAS: + return run_keras + elif impl_name == TENSORFLOW: + return run_tensorflow + elif impl_name == TORCH: + return run_torch + assert False + + +def download_xml2txt(): + if not Path("xml2txt.pl").exists(): + print("downloading xml2txt.pl ...") + os.system("wget https://www.dropbox.com/s/p3ta9spzfviovk0/xml2txt.pl") + + +def prepare_data(zip_path, size_mb): + expected_extension = ".xml.bz2" + assert zip_path.endswith(expected_extension) + base_path = Path(zip_path).parent + unzip_path = base_path / "wiki.xml" + full_text_path = base_path / "wiki.txt" + cutted_text_path = base_path / f"wiki_{size_mb}MB.txt" + if not Path(unzip_path).exists(): + print(f"unziping file {zip_path} ...") + assert os.system(f"bzip2 -kdc {zip_path} > {unzip_path}") == 0 + if not Path(full_text_path).exists(): + print(f"converting xml to text {unzip_path} ...") + download_xml2txt() + preprocess_command = f"perl xml2txt.pl " + preprocess_command += f" -nomath -notables " + preprocess_command += f" {unzip_path} {full_text_path}" + assert os.system(preprocess_command) == 0 + if not Path(cutted_text_path).exists(): + byte_processed = 0 + with open(cutted_text_path, "w") as fout: + with open(full_text_path, "r") as fin: + while 
byte_processed < size_mb * 1_000_000: + s = fin.readline() + byte_processed += len(s.encode()) + fout.write(s) + return cutted_text_path + + +def check_inference_file(algorithm, text_file, vocab_file, n_threads, out_file): + wordpiece = get_wordpiece(algorithm) + start_time = time() + res = wordpiece(text_file, vocab_file, n_threads, out_file) + elapsed = time() - start_time + print(f"Runner returned: {res}") + return elapsed + + +def speed_test(text_file: str, vocab_file: str, algorithms, n_threads: int, collect: bool,): + result = {} + for algorithm in algorithms: + print(f'Running {algorithm}') + out_file = f"result_{algorithm}.txt" if collect else None + time_infer = check_inference_file(algorithm, text_file, vocab_file, n_threads, out_file) + print(f'{algorithm} finished in {time_infer:.1f} sec') + result[algorithm] = time_infer + + return result + + +def print_results(cfg, result_name, corpuses, algorithms): + result_table = [ + ["#" for _ in range(len(corpuses) + 1)] for _ in range(len(algorithms)) + ] + table_header = ["#"] + [lang for lang in corpuses] + rev_lang = {lang: i for i, lang in enumerate(table_header)} + rev_algo = {algo: i for i, algo in enumerate(algorithms)} + for i, algo_name in enumerate(algorithms): + result_table[i][0] = algo_name + + for lang, res in cfg.items(): + best = min(res.values()) + for algo in res: + j = rev_lang[lang] + i = rev_algo[algo] + multiplier_str = f"{res[algo]/best:.1f}".rstrip('0').rstrip('.') + result_table[i][j] = f"{res[algo]:.1f} (x{multiplier_str})" + + table_header[0] = result_name + column_align = ["left"] + ["center" for _ in corpuses] + print(tabulate(result_table, table_header, tablefmt="github", colalign=column_align)) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vocab", type=str, required=True, help="path to vocab file" + ) + parser.add_argument("--n_threads", type=int, default=8) + parser.add_argument( + "--corpus_size", type=int, default=10, help="Size of testing corpus in MB" + ) + parser.add_argument( + "--langs", + type=str, + nargs="+", + help="list of languages for speed test", + default="en", + ) + parser.add_argument("--collect", action="store_true") + + return parser.parse_args() + + +def main(args): + langs = args.langs if isinstance(args.langs, list) else [args.langs] + # Hugging Face - limit number of processes + os.environ["RAYON_RS_NUM_CPUS"] = str(args.n_threads) + + short_to_long_names = { + "en": "English", + "ru": "Russian", + "ja": "Japanese", + "zh": "Chinese", + } + + # For adding more languages check out this page https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/ + all_links = { + "English": "https://www.dropbox.com/s/cnrhd11zdtc1pic/enwiki-20181001-corpus.xml.bz2?dl=1", + "Russian": "https://www.dropbox.com/s/lpfmyrl7nxn5ugg/ruwiki-20181001-corpus.xml.bz2?dl=1", + "Japanese": "https://www.dropbox.com/s/wf496hlu512z9kc/jawiki-20140807-corpus.xml.bz2?dl=1", + "Chinese": "https://www.dropbox.com/s/czhr6s5jwaljeue/zhwiki-20140804-corpus.xml.bz2?dl=1", + } + links = { + short_to_long_names[lang]: all_links[short_to_long_names[lang]] + for lang in langs + } + + corpuses = {} + Path("data").mkdir(exist_ok=True) + for lang, link in links.items(): + Path(f"data/{lang}").mkdir(exist_ok=True) + zip_file = f"data/{lang}/wiki.xml.bz2" + if not Path(zip_file).exists(): + os.system(f"wget -O {zip_file} {link}") + corpuses[lang] = prepare_data(zip_file, args.corpus_size) + + global_tokenization = {} + + for lang, corpus_path in corpuses.items(): + 
tokenization_stat = speed_test(corpus_path, args.vocab, ALGORITHMS, args.n_threads, args.collect) + global_tokenization[lang] = tokenization_stat + + print_results(global_tokenization, f"Tokenization {args.corpus_size}MB", corpuses, ALGORITHMS) + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/tests/unit_tests/README.md b/tests/unit_tests/bpe/README.md similarity index 64% rename from tests/unit_tests/README.md rename to tests/unit_tests/bpe/README.md index f2ef135..500bf35 100644 --- a/tests/unit_tests/README.md +++ b/tests/unit_tests/bpe/README.md @@ -3,7 +3,4 @@ For tests execution simply run: pip install pytest pytest ``` -Testing may take several minutes. - - - +Testing may take several minutes. \ No newline at end of file diff --git a/tests/unit_tests/stress_test.cpp b/tests/unit_tests/bpe/stress_test.cpp similarity index 100% rename from tests/unit_tests/stress_test.cpp rename to tests/unit_tests/bpe/stress_test.cpp diff --git a/tests/unit_tests/stress_test.h b/tests/unit_tests/bpe/stress_test.h similarity index 91% rename from tests/unit_tests/stress_test.h rename to tests/unit_tests/bpe/stress_test.h index 93732d7..67e5058 100644 --- a/tests/unit_tests/stress_test.h +++ b/tests/unit_tests/bpe/stress_test.h @@ -1,6 +1,6 @@ #pragma once -#include "../../youtokentome/cpp/third_party/flat_hash_map.h" +#include "../../youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h" #include "../../youtokentome/cpp/utils.h" namespace vkcom { diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/bpe/test_cli.py similarity index 100% rename from tests/unit_tests/test_cli.py rename to tests/unit_tests/bpe/test_cli.py diff --git a/tests/unit_tests/test_manual.py b/tests/unit_tests/bpe/test_manual.py similarity index 100% rename from tests/unit_tests/test_manual.py rename to tests/unit_tests/bpe/test_manual.py diff --git a/tests/unit_tests/test_python_api.py b/tests/unit_tests/bpe/test_python_api.py similarity index 100% rename from tests/unit_tests/test_python_api.py rename to tests/unit_tests/bpe/test_python_api.py diff --git a/tests/unit_tests/test_stress.py b/tests/unit_tests/bpe/test_stress.py similarity index 100% rename from tests/unit_tests/test_stress.py rename to tests/unit_tests/bpe/test_stress.py diff --git a/tests/unit_tests/utils_for_testing.py b/tests/unit_tests/bpe/utils_for_testing.py similarity index 100% rename from tests/unit_tests/utils_for_testing.py rename to tests/unit_tests/bpe/utils_for_testing.py diff --git a/tests/unit_tests/wordpiece/README.md b/tests/unit_tests/wordpiece/README.md new file mode 100644 index 0000000..500bf35 --- /dev/null +++ b/tests/unit_tests/wordpiece/README.md @@ -0,0 +1,6 @@ +For tests execution simply run: +``` +pip install pytest +pytest +``` +Testing may take several minutes. 
\ No newline at end of file diff --git a/tests/unit_tests/wordpiece/test_cli.py b/tests/unit_tests/wordpiece/test_cli.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/wordpiece/test_manual.py b/tests/unit_tests/wordpiece/test_manual.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/wordpiece/test_python_api.py b/tests/unit_tests/wordpiece/test_python_api.py new file mode 100644 index 0000000..e69de29 diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index c28ee8a..8348f97 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -19,7 +19,7 @@ #include #include -#include "third_party/flat_hash_map.h" +#include "third_party/flat_hash_map/flat_hash_map.h" #include "utf8.h" #include "utils.h" @@ -28,43 +28,8 @@ using std::string; using std::vector; using std::unordered_set; -struct VectorSegment { - constexpr static uint64_t MOD = 2032191299; - constexpr static uint64_t P = 726328703; - - const char* begin; - const char* end; - uint64_t hash; - - VectorSegment(const char* begin, const char* end): begin(begin), end(end) { - hash = 0; - for (auto it = begin; it != end; it++) { - hash = (hash * P + (unsigned char)(*it)) % MOD; - } - } - - bool operator==(const VectorSegment &other) const { - if (other.hash != hash || end - begin != other.end - other.begin) { - return false; - } - for (auto it = begin, other_it = other.begin; it != end; it++, other_it++) { - if (*it != *other_it) { - return false; - } - } - return true; - } -}; - } // namespace vkcom -namespace std { -template<> -struct hash { - uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash; } -}; -} // namespace std - namespace vkcom { Status fast_read_file_utf8(const string &file_name, string *file_content) { @@ -96,10 +61,6 @@ string token2word(const vector &source, return encode_utf8(res); } -bool is_space(uint32_t ch) { - return (ch < 256 && isspace(ch)) || (ch == SPACE_TOKEN); -} - uint64_t int2comb(uint32_t a, uint32_t b) { return (static_cast(a) << 32u) + b; } @@ -133,52 +94,6 @@ struct MergeCandidate { } }; -struct UTF8Iterator { - UTF8Iterator(char* begin, char* end): begin(begin), end(end) {} - - UTF8Iterator operator++() { - if (!state) { - parse(); - } - begin += utf8_len; - state = false; - return *this; - } - - uint32_t operator*() { - if (!state) { - parse(); - } - return code_point; - } - - char* get_ptr() { - return begin; - } - uint64_t get_utf8_len() { - return utf8_len; - } - - bool empty() { - assert(begin <= end); - return begin == end; - } - private: - char *begin, *end; - uint32_t code_point = 0; - uint64_t utf8_len = 0; - bool state = false; - void parse() { - if (state) { - return; - } - assert(!empty()); - code_point = chars_to_utf8(begin, end - begin, &utf8_len); - state = true; - } -}; - - struct Position { uint64_t word_id, pos_id; @@ -469,7 +384,8 @@ flat_hash_map compute_word_count( char* begin_of_word = utf8_iter.get_ptr(); for (; !utf8_iter.empty() && !is_space(*utf8_iter); ++utf8_iter); char* end_of_word = utf8_iter.get_ptr(); - VectorSegment word_hash(begin_of_word, end_of_word); + VectorSegmentBuilder word_hash_builder(begin_of_word, end_of_word); + VectorSegment word_hash = word_hash_builder.finish(); auto it = hash2wordcnt.find(word_hash); if (it == hash2wordcnt.end()) { word.clear(); diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index 99464a2..151ff43 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -3,7 +3,7 @@ #include #include #include -#include 
"third_party/flat_hash_map.h" +#include "third_party/flat_hash_map/flat_hash_map/flat_hash_map.h" #include "utils.h" @@ -19,8 +19,6 @@ enum OutputType { ID, SUBWORD }; Status train_bpe(const std::string &input_path, const std::string &model_path, int vocab_size, BpeConfig config); -void print_vocab(const std::string &model_path, bool verbose); - class BaseEncoder { public: BPEState bpe_state; @@ -83,4 +81,4 @@ class BaseEncoder { ) const; }; -} // namespace vkcom +} // namespace vkcom diff --git a/youtokentome/cpp/third_party/LICENSE b/youtokentome/cpp/third_party/flat_hash_map/LICENSE similarity index 100% rename from youtokentome/cpp/third_party/LICENSE rename to youtokentome/cpp/third_party/flat_hash_map/LICENSE diff --git a/youtokentome/cpp/third_party/flat_hash_map.h b/youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h similarity index 100% rename from youtokentome/cpp/third_party/flat_hash_map.h rename to youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h diff --git a/youtokentome/cpp/third_party/thread_pool/LICENSE b/youtokentome/cpp/third_party/thread_pool/LICENSE new file mode 100644 index 0000000..3b66ae6 --- /dev/null +++ b/youtokentome/cpp/third_party/thread_pool/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Ibragim Dzhiblavi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/youtokentome/cpp/third_party/thread_pool/thread_pool.h b/youtokentome/cpp/third_party/thread_pool/thread_pool.h new file mode 100644 index 0000000..96884f6 --- /dev/null +++ b/youtokentome/cpp/third_party/thread_pool/thread_pool.h @@ -0,0 +1,91 @@ +// Copyright (c) 2023 Ibragim Dzhiblavi +// Modified 2023 Gleb Koveshnikov + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace vkcom { + +class ThreadPool { + public: + using Task = std::function; + + public: + ThreadPool(size_t thread_count) { + if (thread_count == 0) { + thread_count = static_cast(std::thread::hardware_concurrency()); + } + if (thread_count == 0) { + thread_count = 8; + } + for (size_t thread = 0; thread < thread_count; ++thread) { + threads_.emplace_back([this] { + while (!stop_.load(std::memory_order_relaxed)) { + std::unique_lock lock(mutex_); + work_cv_.wait(lock, [this] { + return stop_.load(std::memory_order_relaxed) || !task_queue_.empty(); + }); + if (stop_.load(std::memory_order_relaxed)) { + break; + } + if (task_queue_.empty()) { + continue; + } + ++active_tasks_; + auto task = std::move(task_queue_.front()); + task_queue_.pop(); + lock.unlock(); + task(); + lock.lock(); + --active_tasks_; + complete_cv_.notify_one(); + } + }); + } + } + + ~ThreadPool() { + stop_.store(true, std::memory_order_relaxed); + work_cv_.notify_all(); + for (auto &thread : threads_) { + if (thread.joinable()) { + thread.join(); + } + } + } + + void submit(Task &&task) { + { + std::lock_guard lg(mutex_); + task_queue_.emplace(std::move(task)); + } + work_cv_.notify_one(); + } + + void waitCompletion() { + std::unique_lock lock(mutex_); + if (active_tasks_ != 0 || !task_queue_.empty()) { + complete_cv_.wait(lock, [this] { return active_tasks_ == 0 && task_queue_.empty(); }); + } + } + + [[nodiscard]] size_t maxThreads() const noexcept { return threads_.size(); } + + private: + std::atomic stop_{false}; + size_t active_tasks_{0}; + std::mutex mutex_; + std::condition_variable work_cv_; + std::condition_variable complete_cv_; + std::vector threads_; + std::queue task_queue_; +}; + +} // namespace vkcom \ No newline at end of file diff --git a/youtokentome/cpp/utf8.cpp b/youtokentome/cpp/utf8.cpp index 3a67172..7334763 100644 --- a/youtokentome/cpp/utf8.cpp +++ b/youtokentome/cpp/utf8.cpp @@ -1,15 +1,28 @@ #include "utf8.h" -#include #include -#include -#include -#include "utils.h" namespace vkcom { -using std::string; -using std::vector; +bool is_space(uint32_t ch) { + return (ch < 256 && std::isspace(static_cast(ch))) || (ch == SPACE_TOKEN); +} + +bool is_punctuation(uint32_t ch) { + return (ch < 256 && std::ispunct(static_cast(ch))) || ch == 183 || ch == 171 + || ch == 187 || ch == 8249 || ch == 8250 || (8208 <= ch && ch <= 8248); +} + +bool is_chinese(uint32_t ch) { + if ((ch >= 0x4E00 && ch <= 0x9FFF) || (ch >= 0x3400 && ch <= 0x4DBF) + || (ch >= 0x20000 && ch <= 0x2A6DF) || (ch >= 0x2A700 && ch <= 0x2B73F) + || (ch >= 0x2B740 && ch <= 0x2B81F) || (ch >= 0x2B820 && ch <= 0x2CEAF) + || (ch >= 0xF900 && ch <= 0xFAFF) || (ch >= 0x2F800 && ch <= 0x2FA1F)) { + return true; + } + return false; +} +bool is_spacing_char(uint32_t ch) { return is_space(ch) || is_punctuation(ch) || is_chinese(ch); } bool check_byte(char x) { return (static_cast(x) & 0xc0u) == 0x80u; } @@ -73,7 +86,13 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { return INVALID_UNICODE; } -void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { +bool 
starts_with_space(const char *begin, int64_t size) { + uint64_t len = 0; + uint32_t symbol = chars_to_utf8(begin, size, &len); + return is_space(symbol); +} + +void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { assert(check_codepoint(x)); if (x <= 0x7f) { @@ -100,16 +119,16 @@ void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { *(it++) = 0x80u | (x & 0x3fu); } -string encode_utf8(const vector& text) { - string utf8_text; +std::string encode_utf8(const std::vector& text) { + std::string utf8_text; for (const uint32_t c : text) { utf8_to_chars(c, std::back_inserter(utf8_text)); } return utf8_text; } -vector decode_utf8(const char* begin, const char* end) { - vector decoded_text; +std::vector decode_utf8(const char* begin, const char* end) { + std::vector decoded_text; uint64_t utf8_len = 0; bool invalid_input = false; for (; begin < end; begin += utf8_len) { @@ -127,7 +146,7 @@ vector decode_utf8(const char* begin, const char* end) { return decoded_text; } -vector decode_utf8(const string& utf8_text) { +std::vector decode_utf8(const std::string& utf8_text) { return decode_utf8(utf8_text.data(), utf8_text.data() + utf8_text.size()); } diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h index ec34831..ae31a26 100644 --- a/youtokentome/cpp/utf8.h +++ b/youtokentome/cpp/utf8.h @@ -1,23 +1,88 @@ #pragma once -#include "utils.h" +#include +#include +#include namespace vkcom { -constexpr static uint32_t INVALID_UNICODE = 0x0fffffff; +const uint32_t SPACE_TOKEN = 9601; + +constexpr static uint32_t INVALID_UNICODE = 0x110000; + +bool is_space(uint32_t ch); + +bool is_punctuation(uint32_t ch); + +bool is_chinese_char(uint32_t ch); + +bool is_spacing_char(uint32_t ch); + +bool check_byte(char x); + +bool check_symbol_start(char x); + +bool check_codepoint(uint32_t x); + +uint64_t utf_length(char ch); uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len); +bool starts_with_space(const char *begin, int64_t size); + +void utf8_to_chars(uint32_t x, std::back_insert_iterator it); + std::string encode_utf8(const std::vector &utext); std::vector decode_utf8(const char *begin, const char *end); std::vector decode_utf8(const std::string &utf8_text); +struct UTF8Iterator { + UTF8Iterator(char* begin, char* end): begin(begin), end(end) {} + + UTF8Iterator operator++() { + if (!state) { + parse(); + } + begin += utf8_len; + state = false; + return *this; + } + uint32_t operator*() { + if (!state) { + parse(); + } + return code_point; + } + char* get_ptr() { + return begin; + } + uint64_t get_utf8_len() { + return utf8_len; + } -} // namespace vkcom + bool empty() { + assert(begin <= end); + return begin == end; + } +private: + char *begin, *end; + uint32_t code_point = 0; + uint64_t utf8_len = 0; + bool state = false; + void parse() { + if (state) { + return; + } + assert(!empty()); + code_point = chars_to_utf8(begin, end - begin, &utf8_len); + state = true; + } +}; +} // namespace vkcom \ No newline at end of file diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 768a817..840b68c 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -6,8 +6,6 @@ #include namespace vkcom { -using std::string; -using std::vector; void SpecialTokens::dump(std::ofstream &fout) { fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id @@ -49,7 +47,7 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const { BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {} -void BPEState::dump(const string 
&file_name) {
+void BPEState::dump(const std::string &file_name) {
   std::ofstream fout(file_name, std::ios::out);
   if (fout.fail()) {
     std::cerr << "Can't open file: " << file_name << std::endl;
@@ -67,7 +65,7 @@ void BPEState::dump(const string &file_name) {
   fout.close();
 }
 
-Status BPEState::load(const string &file_name) {
+Status BPEState::load(const std::string &file_name) {
   char2id.clear();
   rules.clear();
   std::ifstream fin(file_name, std::ios::in);
@@ -98,16 +96,29 @@ BpeConfig::BpeConfig(double _character_coverage, int _n_threads,
       n_threads(_n_threads),
       special_tokens(_special_tokens) {}
 
-vector<string> read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) {
-  vector<string> sentences;
-  string s;
-  while (*processed < batch_limit && getline(std::cin, s)) {
+std::vector<std::string> read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) {
+  std::vector<std::string> sentences;
+  std::string s;
+  while (*processed < batch_limit && std::getline(std::cin, s)) {
     *processed += s.size();
     sentences.push_back(std::move(s));
   }
   return sentences;
 }
 
+std::string read_file(const std::string& path) {
+  std::ifstream t(path);
+  std::string str;
+
+  t.seekg(0, std::ios::end);
+  str.reserve(t.tellg());
+  t.seekg(0, std::ios::beg);
+
+  str.assign((std::istreambuf_iterator<char>(t)),
+             std::istreambuf_iterator<char>());
+  return str;
+}
+
 Status::Status(int code, std::string message) : code(code), message(std::move(message)) {}
 
 const std::string &Status::error_message() const {
@@ -116,4 +127,5 @@ const std::string &Status::error_message() const {
 
 bool Status::ok() const { return code == 0; }
 
+
 } // namespace vkcom
diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h
index ce802d5..f682e73 100644
--- a/youtokentome/cpp/utils.h
+++ b/youtokentome/cpp/utils.h
@@ -3,10 +3,9 @@
 #include <fstream>
 #include <iostream>
 #include <vector>
-#include "third_party/flat_hash_map.h"
+#include "third_party/flat_hash_map/flat_hash_map.h"
 
 namespace vkcom {
-const uint32_t SPACE_TOKEN = 9601;
 
 struct BPE_Rule {
   // x + y -> z
@@ -85,21 +84,105 @@ struct EncodingConfig {
   double dropout_prob;
 };
 
-bool is_space(uint32_t ch);
-
 std::vector<std::string> read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed);
 
+std::string read_file(const std::string& path);
+
+template <typename T>
+void write_to_stdout(const std::vector<T> &items, bool flush) {
+  for (const auto &item : items) {
+    std::cout << item << " ";
+  }
+  std::cout << "\n";
+  if (flush) {
+    std::cout << std::flush;
+  }
+}
+
 template <typename T>
 void write_to_stdout(const std::vector<std::vector<T>> &sentences, bool flush) {
   for (const auto &sentence : sentences) {
-    for (const auto &token : sentence) {
-      std::cout << token << " ";
-    }
-    std::cout << "\n";
+    write_to_stdout(sentence, false);
   }
   if (flush) {
     std::cout << std::flush;
   }
 }
 
-} // namespace vkcom
+class VectorSegmentBuilder;
+
+struct VectorSegment {
+ private:
+  friend class VectorSegmentBuilder;
+
+  const uint32_t *begin_;
+  const uint32_t *end_;
+  const uint64_t hash_;
+
+  VectorSegment(const uint32_t *begin, const uint32_t *end, uint64_t hash)
+      : begin_(begin), end_(end), hash_(hash) {}
+
+ public:
+  bool operator==(const VectorSegment &other) const {
+    if (other.hash() != hash() || end_ - begin_ != other.end_ - other.begin_) {
+      return false;
+    }
+    for (auto it = begin_, other_it = other.begin_; it != end_; it++, other_it++) {
+      if (*it != *other_it) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  uint64_t hash() const { return hash_; }
+};
+
+class VectorSegmentBuilder {
+ private:
+  constexpr static uint64_t MOD = 2032191299;
+  constexpr static uint64_t P = 726328703;
+
+  const uint32_t *begin_;
+ const uint32_t *end_; + std::vector prefix_hash_; + + public: + VectorSegmentBuilder(const std::vector &segment) + : VectorSegmentBuilder(segment.data(), segment.data() + segment.size()) {} + + VectorSegmentBuilder(const uint32_t *begin, const uint32_t *end) : begin_(begin), end_(end) { + uint64_t hash = 0; + prefix_hash_.reserve(static_cast(end - begin)); + for (const uint32_t *it = begin_; it != end_; it++) { + hash = (hash * P + *it) % MOD; + prefix_hash_.push_back(hash); + } + } + + VectorSegment finish() const { return VectorSegment(begin_, end_, hash()); } + + size_t size() const { return prefix_hash_.size(); } + + bool empty() const { return prefix_hash_.empty(); } + + uint64_t hash() const { return prefix_hash_.empty() ? 0 : prefix_hash_.back(); } + + void pop_back() noexcept { + if (!prefix_hash_.empty()) { + prefix_hash_.pop_back(); + --end_; + } + } +}; + +} // namespace vkcom + +namespace std { + +template <> +struct hash { + uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash(); } +}; + +} // namespace std \ No newline at end of file diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp new file mode 100644 index 0000000..5e18dc5 --- /dev/null +++ b/youtokentome/cpp/wordpiece.cpp @@ -0,0 +1,234 @@ +#include "wordpiece.h" + +#include +#include +#include +#include +#include +#include + +#include "third_party/flat_hash_map/flat_hash_map.h" +#include "third_party/thread_pool/thread_pool.hpp" +#include "utf8.h" + +namespace { + +struct WordPieceToken { + explicit WordPieceToken(const std::string &encoded_word) + : is_prefix(true), is_special(false), is_malformed(false), word(vkcom::decode_utf8(encoded_word)) { + if (isSuffixVocab(word)) { + is_prefix = false; + word.erase(word.begin(), word.begin() + 2); + } else if (isSpecialToken(word)) { + is_special = true; + } + + bool all_punctuation = true; + for (uint32_t code_point : word) { + if (code_point == vkcom::INVALID_UNICODE) { + is_malformed = true; + } + if (!vkcom::is_punctuation(code_point) && !vkcom::is_space(code_point)) { + all_punctuation = false; + } + } + if (word.empty()) { + throw std::runtime_error("Vocab word is empty"); + } + if (is_malformed || (all_punctuation && word.size() > 1)) { + is_malformed = true; + std::cerr << "Vocab word is malformed: " << encoded_word << std::endl; + } +} + + bool is_prefix; + bool is_special; + bool is_malformed; + std::vector word; +}; + +struct WordPieceVocabulary { + static constexpr int kDefaultUnkTokenId = -1; + + std::vector tokens; + int unk_token_id = kDefaultUnkTokenId; +}; + +WordPieceVocabulary readVocabFromFile(const std::string &file) { + WordPieceVocabulary vocab_utf8; + std::ifstream fin(file); + std::string word; + int token_id = 0; + while (std::getline(fin, word)) { + if (word == kUnkTokenIdStr) { + vocab_utf8.unk_token_id = token_id; + } + WordPieceToken token(word); + vocab_utf8.tokens.push_back(std::move(token)); + ++token_id; + } + return vocab_utf8; +} + +vkcom::ThreadPool &globalThreadPool(size_t n_threads) { + static vkcom::ThreadPool thread_pool(n_threads); + return thread_pool; +} + +std::vector encodeWordPieceImpl(const std::vector &text, + const WordPieceVocabulary &vocab) { + using WordMap = std::unordered_map; + WordMap prefix_to_id; // no ## in word prefix + WordMap suffix_to_id; // ## in word prefix + + size_t max_len = 0; + for (size_t i = 0; i < vocab.tokens.size(); i++) { + const auto &token = vocab.tokens[i]; + if (token.is_special || token.is_malformed) { + continue; + } + max_len = std::max(max_len, 
token.word.size()); + vkcom::VectorSegmentBuilder segment(token.word); + WordMap *word_to_id = token.is_prefix ? &prefix_to_id : &suffix_to_id; + (*word_to_id)[segment.finish()] = static_cast(i); + } + max_len = std::min(max_len, text.size()); + + const auto is_word_prefix = [&text](size_t index) { + return index == 0 || vkcom::is_spacing_char(text[index]) + || vkcom::is_spacing_char(text[index - 1]); + }; + + const auto worker = [&, unk_token_id = vocab.unk_token_id](size_t begin, size_t end) { + std::vector token_ids; + token_ids.reserve((end - begin) / max_len + 1); + + while (begin != end && vkcom::is_space(text[begin])) { + ++begin; + } + + size_t tokens_since_prefix = 0; + + while (begin != end) { + size_t word_len = 1; + if (!vkcom::is_punctuation(text[begin])) { + while (word_len < std::min(max_len, end - begin) + && !vkcom::is_spacing_char(text[begin + word_len])) { + ++word_len; + } + } + + const uint32_t *segment_begin = text.data() + static_cast(begin); + const uint32_t *segment_end = segment_begin + static_cast(word_len); + const WordMap *word_to_id = is_word_prefix(begin) ? &prefix_to_id : &suffix_to_id; + + vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); + while (!segment.empty()) { + auto it = word_to_id->find(segment.finish()); + if (it != word_to_id->end()) { + ++tokens_since_prefix; + token_ids.push_back(it->second); + begin += segment.size(); + break; + } else { + segment.pop_back(); + } + } + + if (segment.empty()) { + while (tokens_since_prefix > 0) { + token_ids.pop_back(); + --tokens_since_prefix; + } + token_ids.push_back(unk_token_id); + begin += word_len; + while (begin != end && !is_word_prefix(begin)) { + ++begin; + } + } else if (begin != end && is_word_prefix(begin)) { + tokens_since_prefix = 0; + } + + while (begin != end && vkcom::is_space(text[begin])) { + ++begin; + } + } + + return token_ids; + }; + + static constexpr size_t kWorkBatch = 1'000'000; + std::vector token_ids; + if (text.size() < 2 * kWorkBatch) { + token_ids = worker(0, text.size()); + } else { + const size_t thread_count + = std::min(globalThreadPool().maxThreads(), text.size() / kWorkBatch); + const size_t work_batch = text.size() / thread_count + 1; + std::vector> per_thread_token_ids(thread_count); + size_t work_begin = 0; + for (size_t thread_id = 0; thread_id < thread_count && work_begin < text.size(); thread_id++) { + size_t work_end = std::min(text.size(), work_begin + work_batch); + while (work_end < text.size() && !vkcom::is_space(text[work_end])) { + ++work_end; + } + globalThreadPool().submit( + [thread_id, work_begin, work_end, &per_thread_token_ids, &worker] { + per_thread_token_ids[thread_id] = worker(work_begin, work_end); + }); + work_begin = work_end; + } + + globalThreadPool().waitCompletion(); + + size_t token_count = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + token_count += per_thread_token_ids[thread_id].size(); + } + token_ids.resize(token_count); + work_begin = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector &segment = per_thread_token_ids[thread_id]; + if (!segment.empty()) { + std::memcpy(token_ids.data() + work_begin, segment.data(), segment.size() * sizeof(int)); + work_begin += segment.size(); + } + } + } + + return token_ids; +} + +std::vector +encodeWordPiece(const char *text, size_t size, const WordPieceVocabulary &vocab) { + if (size == 0) { + return {}; + } + const std::vector text_utf8 = utils::parseText(text, size, globalThreadPool()); + return 
encodeWordPieceImpl(text_utf8, vocab); +} + +} // namespace + +namespace vkcom { + +/*std::vector encode_wordpiece(const std::string& input_path, const std::string& vocab_path) { + const WordPieceVocabulary vocab_utf8 = readVocabFromFile(vocab_file); + // TODO: use mapped file + std::string text = read_file(input_path); + return encodeWordPiece(text.data(), text.size(), vocab_utf8); +} + +Status encode_wordpiece_cli(const std::string& input_path, const std::string& vocab_path) { + try { + std::vector ids = encode_wordpiece(input_path, vocab_path); + write_to_stdout(ids, true); + return Status(); + } catch (const std::exception& ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); + } +}*/ + +} // namespace vkcom \ No newline at end of file diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h new file mode 100644 index 0000000..98375d5 --- /dev/null +++ b/youtokentome/cpp/wordpiece.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace vkcom::wordpiece { + +/*Status encode_as_ids(const std::string &text, + const std::vector& vocab, std::vector *ids); + +Status encode_as_subwords(const std::string &text, + const std::vector& vocab, + std::vector> *subwords); + +Status encode_wordpiece_cli(const std::string& vocab_path);*/ + +} // namespace vkcom::wordpiece \ No newline at end of file From fd1398965b6af0457be47243840847fbb719810e Mon Sep 17 00:00:00 2001 From: Gleb Koveshnikov Date: Sat, 1 Apr 2023 14:14:04 +0400 Subject: [PATCH 02/10] cleanup --- MANIFEST.in | 2 + tests/unit_tests/bpe/stress_test.cpp | 1 - tests/unit_tests/bpe/stress_test.h | 20 ------- tests/unit_tests/bpe/test_python_api.py | 1 + tests/unit_tests/wordpiece/test_cli.py | 3 ++ tests/unit_tests/wordpiece/test_manual.py | 4 ++ tests/unit_tests/wordpiece/test_python_api.py | 4 ++ youtokentome/cpp/bpe.cpp | 4 -- youtokentome/cpp/bpe.h | 19 ++++--- youtokentome/cpp/utf8.h | 52 +----------------- youtokentome/cpp/utils.cpp | 11 ---- youtokentome/cpp/utils.h | 54 +++++++++++-------- youtokentome/cpp/wordpiece.cpp | 4 +- youtokentome/cpp/yttm.pyx | 3 +- 14 files changed, 62 insertions(+), 120 deletions(-) delete mode 100644 tests/unit_tests/bpe/stress_test.h diff --git a/MANIFEST.in b/MANIFEST.in index 92d8a80..4ade57c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,10 @@ include youtokentome/cpp/utils.h include youtokentome/cpp/bpe.h include youtokentome/cpp/utf8.h +include youtokentome/cpp/wordpiece.h include youtokentome/cpp/yttm.pyx include youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h +include youtokentome/cpp/third_party/thread_pool/thread_pool.h include youtokentome/cpp/third_party/LICENSE include LICENSE include README.md diff --git a/tests/unit_tests/bpe/stress_test.cpp b/tests/unit_tests/bpe/stress_test.cpp index 91a7c63..db4b2a2 100644 --- a/tests/unit_tests/bpe/stress_test.cpp +++ b/tests/unit_tests/bpe/stress_test.cpp @@ -5,7 +5,6 @@ #include #include #include -#include "stress_test.h" #include "../../youtokentome/cpp/utils.h" #include "../../youtokentome/cpp/bpe.h" diff --git a/tests/unit_tests/bpe/stress_test.h b/tests/unit_tests/bpe/stress_test.h deleted file mode 100644 index e5d80c7..0000000 --- a/tests/unit_tests/bpe/stress_test.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "../../youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h" -#include "../../youtokentome/cpp/utils.h" - -namespace vkcom { - -flat_hash_map -compute_alphabet_helper(const flat_hash_map &char_cnt, - uint64_t data_len, - flat_hash_set 
&removed_chars, - const BpeConfig &bpe_config); - -Status learn_bpe_from_string(std::string &text_utf8, - int n_tokens, - const std::string &output_file, - BpeConfig bpe_config, - BPEState *bpe_state); - -} // namespace vkcom diff --git a/tests/unit_tests/bpe/test_python_api.py b/tests/unit_tests/bpe/test_python_api.py index 4fce4c5..2bfe10f 100644 --- a/tests/unit_tests/bpe/test_python_api.py +++ b/tests/unit_tests/bpe/test_python_api.py @@ -2,6 +2,7 @@ import random import youtokentome as yttm + from utils_for_testing import ( BASE_MODEL_FILE, RENAME_ID_MODEL_FILE, diff --git a/tests/unit_tests/wordpiece/test_cli.py b/tests/unit_tests/wordpiece/test_cli.py index e69de29..49960d2 100644 --- a/tests/unit_tests/wordpiece/test_cli.py +++ b/tests/unit_tests/wordpiece/test_cli.py @@ -0,0 +1,3 @@ +import os +import random +from subprocess import run \ No newline at end of file diff --git a/tests/unit_tests/wordpiece/test_manual.py b/tests/unit_tests/wordpiece/test_manual.py index e69de29..b49d2ab 100644 --- a/tests/unit_tests/wordpiece/test_manual.py +++ b/tests/unit_tests/wordpiece/test_manual.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +import os + +import youtokentome as yttm \ No newline at end of file diff --git a/tests/unit_tests/wordpiece/test_python_api.py b/tests/unit_tests/wordpiece/test_python_api.py index e69de29..a8acf38 100644 --- a/tests/unit_tests/wordpiece/test_python_api.py +++ b/tests/unit_tests/wordpiece/test_python_api.py @@ -0,0 +1,4 @@ +import os +import random + +import youtokentome as yttm \ No newline at end of file diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index ea76416..da91328 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -25,10 +25,6 @@ namespace vkcom { -} // namespace vkcom - -namespace vkcom { - Status fast_read_file_utf8(const std::string &file_name, std::string *file_content) { static const int buf_size = 1000000; *file_content = ""; diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index 151ff43..178f306 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -9,16 +9,21 @@ namespace vkcom { -const std::string UNK_TOKEN = ""; -const std::string PAD_TOKEN = ""; -const std::string BOS_TOKEN = ""; -const std::string EOS_TOKEN = ""; - -enum OutputType { ID, SUBWORD }; - Status train_bpe(const std::string &input_path, const std::string &model_path, int vocab_size, BpeConfig config); +Status learn_bpe_from_string(std::string &text_utf8, + int n_tokens, + const std::string &output_file, + BpeConfig bpe_config, + BPEState *bpe_state); + +flat_hash_map +compute_alphabet_helper(const flat_hash_map &char_cnt, + uint64_t data_len, + flat_hash_set &removed_chars, + const BpeConfig &bpe_config); + class BaseEncoder { public: BPEState bpe_state; diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h index a0e1e4b..353e3f7 100644 --- a/youtokentome/cpp/utf8.h +++ b/youtokentome/cpp/utf8.h @@ -1,14 +1,8 @@ #pragma once -<<<<<<< HEAD #include #include #include -======= -#include -#include -#include ->>>>>>> master namespace vkcom { @@ -34,14 +28,11 @@ uint64_t utf_length(char ch); uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len); -<<<<<<< HEAD bool starts_with_space(const char *begin, int64_t size); -======= ->>>>>>> master void utf8_to_chars(uint32_t x, std::back_insert_iterator it); -std::string encode_utf8(const std::vector &utext); +std::string encode_utf8(const std::vector &text); std::vector decode_utf8(const char *begin, const char *end); @@ -49,7 +40,6 @@ 
std::vector decode_utf8(const std::string &utf8_text); struct UTF8Iterator { UTF8Iterator(char* begin, char* end): begin(begin), end(end) {} -<<<<<<< HEAD UTF8Iterator operator++() { if (!state) { @@ -70,6 +60,7 @@ struct UTF8Iterator { char* get_ptr() { return begin; } + uint64_t get_utf8_len() { return utf8_len; } @@ -85,41 +76,6 @@ struct UTF8Iterator { uint64_t utf8_len = 0; bool state = false; -======= - - UTF8Iterator operator++() { - if (!state) { - parse(); - } - begin += utf8_len; - state = false; - return *this; - } - - uint32_t operator*() { - if (!state) { - parse(); - } - return code_point; - } - - char* get_ptr() { - return begin; - } - uint64_t get_utf8_len() { - return utf8_len; - } - - bool empty() { - assert(begin <= end); - return begin == end; - } -private: - char *begin, *end; - uint32_t code_point = 0; - uint64_t utf8_len = 0; - bool state = false; ->>>>>>> master void parse() { if (state) { return; @@ -130,8 +86,4 @@ struct UTF8Iterator { } }; -<<<<<<< HEAD -} // namespace vkcom -======= } // namespace vkcom ->>>>>>> master diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 2f6ee19..840b68c 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -96,21 +96,10 @@ BpeConfig::BpeConfig(double _character_coverage, int _n_threads, n_threads(_n_threads), special_tokens(_special_tokens) {} -<<<<<<< HEAD std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) { std::vector sentences; std::string s; while (*processed < batch_limit && std::getline(std::cin, s)) { -======= -bool is_space(uint32_t ch) { - return (ch < 256 && isspace(ch)) || (ch == SPACE_TOKEN); -} - -std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) { - std::vector sentences; - std::string s; - while (*processed < batch_limit && getline(std::cin, s)) { ->>>>>>> master *processed += s.size(); sentences.push_back(std::move(s)); } diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index f682e73..f947f37 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -3,21 +3,31 @@ #include #include #include + #include "third_party/flat_hash_map/flat_hash_map.h" namespace vkcom { -struct BPE_Rule { - // x + y -> z - uint32_t x{0}; - uint32_t y{0}; - uint32_t z{0}; +const std::string UNK_TOKEN = ""; +const std::string PAD_TOKEN = ""; +const std::string BOS_TOKEN = ""; +const std::string EOS_TOKEN = ""; - BPE_Rule() = default; +enum OutputType { ID, SUBWORD }; - BPE_Rule(uint32_t x, uint32_t y, uint32_t z); +struct DecodeResult { + std::vector ids; + std::vector pieces; +}; - bool operator==(const BPE_Rule &other) const; +struct Status { + int code{0}; + std::string message; + Status() = default; + Status(int code, std::string message); + + const std::string &error_message() const; + bool ok() const; }; struct SpecialTokens { @@ -41,6 +51,19 @@ struct SpecialTokens { uint64_t n_special_tokens() const; }; +struct BPE_Rule { + // x + y -> z + uint32_t x{0}; + uint32_t y{0}; + uint32_t z{0}; + + BPE_Rule() = default; + + BPE_Rule(uint32_t x, uint32_t y, uint32_t z); + + bool operator==(const BPE_Rule &other) const; +}; + struct BpeConfig { double character_coverage = 1; int n_threads = 0; @@ -52,16 +75,6 @@ struct BpeConfig { const SpecialTokens &special_tokens); }; -struct Status { - int code{0}; - std::string message; - Status() = default; - Status(int code, std::string message); - - const std::string &error_message() const; - bool ok() const; -}; - struct BPEState { flat_hash_map char2id; std::vector 
rules; @@ -72,11 +85,6 @@ struct BPEState { Status load(const std::string &file_name); }; -struct DecodeResult { - std::vector ids; - std::vector pieces; -}; - struct EncodingConfig { bool bos; bool eos; diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp index 5e18dc5..1d05b82 100644 --- a/youtokentome/cpp/wordpiece.cpp +++ b/youtokentome/cpp/wordpiece.cpp @@ -210,7 +210,7 @@ encodeWordPiece(const char *text, size_t size, const WordPieceVocabulary &vocab) } // namespace -namespace vkcom { +namespace vkcom::wordpiece { /*std::vector encode_wordpiece(const std::string& input_path, const std::string& vocab_path) { const WordPieceVocabulary vocab_utf8 = readVocabFromFile(vocab_file); @@ -231,4 +231,4 @@ Status encode_wordpiece_cli(const std::string& input_path, const std::string& vo } }*/ -} // namespace vkcom \ No newline at end of file +} // namespace vkcom::wordpiece \ No newline at end of file diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index 1d7774d..3bddb96 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -7,8 +7,7 @@ from pathlib import Path from typing import Collection -cdef extern from "bpe.h" namespace "vkcom": - +cdef extern from "utils.h" namespace "vkcom": cdef cppclass SpecialTokens: int pad_id int unk_id From 8e9df07f91727235db0464aa1e97bfa73916b064 Mon Sep 17 00:00:00 2001 From: Gleb Koveshnikov Date: Sat, 1 Apr 2023 14:42:50 +0400 Subject: [PATCH 03/10] cleanup 2 --- README.md | 24 ++++---- tests/speed_test/bpe.py | 2 +- tests/speed_test/wordpiece.py | 4 +- youtokentome/cpp/bpe.cpp | 69 ++++++++++++++++------ youtokentome/cpp/bpe.h | 45 ++++++++++++++- youtokentome/cpp/utf8.h | 1 + youtokentome/cpp/utils.cpp | 105 +++++++++------------------------- youtokentome/cpp/utils.h | 45 +-------------- 8 files changed, 142 insertions(+), 153 deletions(-) diff --git a/README.md b/README.md index a4ec64f..a9a2f73 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ # YouTokenToMe YouTokenToMe is an unsupervised text tokenizer focused on computational efficiency. It currently contains the fastest implementations of: -- Byte Pair Encoding (BPE) [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)], [benchmark results](benchmark.md); -- WordPiece [[Song et al.](https://arxiv.org/abs/2012.15524)], [benchmark results](benchmark.md). +- Byte Pair Encoding (BPE) [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)], [benchmark results](benchmark_bpe.md); +- WordPiece [[Song et al.](https://arxiv.org/abs/2012.15524)], [benchmark results](benchmark_wordpiece.md). 
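+
+For illustration of the WordPiece output format, word-continuation tokens are marked with a leading `##`. With a toy vocabulary (hypothetical, not one shipped with the library) containing `un`, `##aff` and `##able`, the word ```unaffable``` is tokenized into
+
+`['un', '##aff', '##able']`
+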
Key advantages:
 
@@ -33,7 +33,7 @@ For example, the phrase ```Blazingly fast tokenization!``` can be tokenized into
 
-Algorighm properties:
+Algorithm properties:
 * Currently supports tokenizer only, but not training
-* Time complexity is `O(NM)`, where `N` is the length of tokenized data and `M` is the max length of word in vocabulary
+* Time complexity is `O(NM^2)`, where `N` is the length of tokenized data and `M` is the maximum word length in the vocabulary
 
 ## Installation
 
@@ -41,9 +41,9 @@ Algorighm properties:
 pip install youtokentome
 ```
 
-## Python interface
+## Python BPE interface
 
-### BPE Example
+### Example
 
 ```python
 import random
@@ -74,12 +74,8 @@ bpe = yttm.BPE(model=model_path)
 print(bpe.encode([test_text], output_type=yttm.OutputType.ID))
 print(bpe.encode([test_text], output_type=yttm.OutputType.SUBWORD))
 ```
-
-### WordPiece Example
-
-TODO
-### BPE Methods
+### Methods
 Class `youtokentome.BPE` has the following methods:
 
 #### constructor
@@ -198,7 +194,13 @@ Convert each id to subword and concatenate with space symbol.
 
 **Returns:** List of strings.
 
-### WordPiece methods
+## Python WordPiece interface
+
+### Example
+
+TODO
+
+### Methods
 
 TODO
 
diff --git a/tests/speed_test/bpe.py b/tests/speed_test/bpe.py
index 68adde8..e695941 100644
--- a/tests/speed_test/bpe.py
+++ b/tests/speed_test/bpe.py
@@ -15,7 +15,7 @@
 YOU_TOKEN_TO_ME = "YouTokenToMe"
 SENTENCE_PIECE = "SentencePiece"
 FAST_BPE = "fastBPE"
-HUGGING_FACE= "Hugging_Face"
+HUGGING_FACE = "Hugging Face"
 
 PATH_TO_FASTBPE = "./fastBPE"
 
diff --git a/tests/speed_test/wordpiece.py b/tests/speed_test/wordpiece.py
index 0502cfd..6ab72dc 100644
--- a/tests/speed_test/wordpiece.py
+++ b/tests/speed_test/wordpiece.py
@@ -12,9 +12,9 @@
 
 YOU_TOKEN_TO_ME = "YouTokenToMe"
-HUGGING_FACE = 'Hugging_Face'
+HUGGING_FACE = 'Hugging Face'
 KERAS = 'Keras'
-TENSORFLOW = 'Tensorflow'
+TENSORFLOW = 'TensorFlow'
 TORCH = 'Torch'
 
 ALGORITHMS = [YOU_TOKEN_TO_ME, HUGGING_FACE, KERAS, TENSORFLOW, TORCH]
 
diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index da91328..52a27b3 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1,5 +1,3 @@
-#include 
-
 #include "bpe.h"
 
 #include 
@@ -7,6 +5,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -25,25 +24,61 @@
 
 namespace vkcom {
 
-Status fast_read_file_utf8(const std::string &file_name, std::string *file_content) {
-  static const int buf_size = 1000000;
-  *file_content = "";
-  auto fin = fopen(file_name.data(), "rb");
-  if (fin == nullptr) {
-    return Status(1, "Failed to open file: " + file_name);
+bool BPE_Rule::operator==(const BPE_Rule &other) const {
+  return x == other.x && y == other.y && z == other.z;
+}
+
+BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}
+
+void BPEState::dump(const std::string &file_name) {
+  std::ofstream fout(file_name, std::ios::out);
+  if (fout.fail()) {
+    std::cerr << "Can't open file: " << file_name << std::endl;
+    assert(false);
   }
-  while (true) {
-    uint64_t cur_size = file_content->size();
-    file_content->resize(cur_size + buf_size);
-    int buf_len = fread((void *) (file_content->data() + cur_size), 1, buf_size, fin);
-    if (buf_len < buf_size) {
-      file_content->resize(file_content->size() - (buf_size - buf_len));
-      fclose(fin);
-      return Status();
-    }
+  fout << char2id.size() << " " << rules.size() << std::endl;
+  for (auto s : char2id) {
+    fout << s.first << " " << s.second << std::endl;
+  }
+
+  for (auto rule : rules) {
+    fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
   }
+  special_tokens.dump(fout);
+ 
fout.close(); } +Status BPEState::load(const std::string &file_name) { + char2id.clear(); + rules.clear(); + std::ifstream fin(file_name, std::ios::in); + if (fin.fail()) { + return Status(1, "Can not open file with model: " + file_name); + } + int n, m; + fin >> n >> m; + for (int i = 0; i < n; i++) { + uint32_t inner_id; + uint32_t utf32_id; + fin >> inner_id >> utf32_id; + char2id[inner_id] = utf32_id; + } + for (int i = 0; i < m; i++) { + uint32_t x, y, z; + fin >> x >> y >> z; + rules.emplace_back(x, y, z); + } + special_tokens.load(fin); + fin.close(); + return Status(); +} + +BpeConfig::BpeConfig(double _character_coverage, int _n_threads, + const SpecialTokens &_special_tokens) + : character_coverage(_character_coverage), + n_threads(_n_threads), + special_tokens(_special_tokens) {} + std::string token2word(const std::vector &source, const flat_hash_map &id2char) { std::vector res; diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index 178f306..70094f8 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -1,14 +1,57 @@ #pragma once -#include #include #include +#include + #include "third_party/flat_hash_map/flat_hash_map/flat_hash_map.h" #include "utils.h" +// TODO: introduce vkcom::bpe namespace namespace vkcom { +struct BPE_Rule { + // x + y -> z + uint32_t x{0}; + uint32_t y{0}; + uint32_t z{0}; + + BPE_Rule() = default; + + BPE_Rule(uint32_t x, uint32_t y, uint32_t z); + + bool operator==(const BPE_Rule &other) const; +}; + +struct BpeConfig { + double character_coverage = 1; + int n_threads = 0; + SpecialTokens special_tokens; + + BpeConfig() = default; + + BpeConfig(double character_coverage, int n_threads, + const SpecialTokens &special_tokens); +}; + +struct BPEState { + flat_hash_map char2id; + std::vector rules; + SpecialTokens special_tokens; + + void dump(const std::string &file_name); + + Status load(const std::string &file_name); +}; + +struct EncodingConfig { + bool bos; + bool eos; + bool reverse; + double dropout_prob; +}; + Status train_bpe(const std::string &input_path, const std::string &model_path, int vocab_size, BpeConfig config); diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h index 353e3f7..0358c3a 100644 --- a/youtokentome/cpp/utf8.h +++ b/youtokentome/cpp/utf8.h @@ -4,6 +4,7 @@ #include #include +// TODO: introduce vkcom::utf8 namespace namespace vkcom { const uint32_t SPACE_TOKEN = 9601; diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 840b68c..e7de2e1 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -1,12 +1,20 @@ #include "utils.h" -#include + #include -#include #include #include namespace vkcom { +Status::Status(int code, std::string message) : code(code), message(std::move(message)) {} + +const std::string &Status::error_message() const { + return message; +} +bool Status::ok() const { + return code == 0; +} + void SpecialTokens::dump(std::ofstream &fout) { fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id << std::endl; @@ -41,61 +49,6 @@ uint64_t SpecialTokens::n_special_tokens() const { SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id) : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {} -bool BPE_Rule::operator==(const BPE_Rule &other) const { - return x == other.x && y == other.y && z == other.z; -} - -BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {} - -void BPEState::dump(const std::string &file_name) { - std::ofstream fout(file_name, std::ios::out); - if (fout.fail()) 
{ - std::cerr << "Can't open file: " << file_name << std::endl; - assert(false); - } - fout << char2id.size() << " " << rules.size() << std::endl; - for (auto s : char2id) { - fout << s.first << " " << s.second << std::endl; - } - - for (auto rule : rules) { - fout << rule.x << " " << rule.y << " " << rule.z << std::endl; - } - special_tokens.dump(fout); - fout.close(); -} - -Status BPEState::load(const std::string &file_name) { - char2id.clear(); - rules.clear(); - std::ifstream fin(file_name, std::ios::in); - if (fin.fail()) { - return Status(1, "Can not open file with model: " + file_name); - } - int n, m; - fin >> n >> m; - for (int i = 0; i < n; i++) { - uint32_t inner_id; - uint32_t utf32_id; - fin >> inner_id >> utf32_id; - char2id[inner_id] = utf32_id; - } - for (int i = 0; i < m; i++) { - uint32_t x, y, z; - fin >> x >> y >> z; - rules.emplace_back(x, y, z); - } - special_tokens.load(fin); - fin.close(); - return Status(); -} - -BpeConfig::BpeConfig(double _character_coverage, int _n_threads, - const SpecialTokens &_special_tokens) - : character_coverage(_character_coverage), - n_threads(_n_threads), - special_tokens(_special_tokens) {} - std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) { std::vector sentences; std::string s; @@ -106,26 +59,24 @@ std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *p return sentences; } -std::string read_file(const std::string& path) { - std::ifstream t("file.txt"); - std::string str; - - t.seekg(0, std::ios::end); - str.reserve(t.tellg()); - t.seekg(0, std::ios::beg); - - str.assign((std::istreambuf_iterator(t)), - std::istreambuf_iterator()); - return str; -} - -Status::Status(int code, std::string message) : code(code), message(std::move(message)) {} - -const std::string &Status::error_message() const { - return message; -} -bool Status::ok() const { - return code == 0; +Status fast_read_file_utf8(const std::string &file_name, std::string *file_content) { + static const int buf_size = 1000000; + *file_content = ""; + // TODO: use ifstream and seekg+tellg+seekg to reserve + auto fin = fopen(file_name.data(), "rb"); + if (fin == nullptr) { + return Status(1, "Failed to open file: " + file_name); + } + while (true) { + uint64_t cur_size = file_content->size(); + file_content->resize(cur_size + buf_size); + int buf_len = fread((void *) (file_content->data() + cur_size), 1, buf_size, fin); + if (buf_len < buf_size) { + file_content->resize(file_content->size() - (buf_size - buf_len)); + fclose(fin); + return Status(); + } + } } } // namespace vkcom diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index f947f37..4ec1a1d 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -4,8 +4,6 @@ #include #include -#include "third_party/flat_hash_map/flat_hash_map.h" - namespace vkcom { const std::string UNK_TOKEN = ""; @@ -51,50 +49,9 @@ struct SpecialTokens { uint64_t n_special_tokens() const; }; -struct BPE_Rule { - // x + y -> z - uint32_t x{0}; - uint32_t y{0}; - uint32_t z{0}; - - BPE_Rule() = default; - - BPE_Rule(uint32_t x, uint32_t y, uint32_t z); - - bool operator==(const BPE_Rule &other) const; -}; - -struct BpeConfig { - double character_coverage = 1; - int n_threads = 0; - SpecialTokens special_tokens; - - BpeConfig() = default; - - BpeConfig(double character_coverage, int n_threads, - const SpecialTokens &special_tokens); -}; - -struct BPEState { - flat_hash_map char2id; - std::vector rules; - SpecialTokens special_tokens; - - void dump(const std::string &file_name); - - 
Status load(const std::string &file_name); -}; - -struct EncodingConfig { - bool bos; - bool eos; - bool reverse; - double dropout_prob; -}; - std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed); -std::string read_file(const std::string& path); +Status fast_read_file_utf8(const std::string &file_name, std::string *file_content); template void write_to_stdout(const std::vector &items, bool flush) { From f0ba916c40a7ec9df56012b9bbb5fcee4f35d6ed Mon Sep 17 00:00:00 2001 From: Gleb Koveshnikov Date: Sat, 1 Apr 2023 14:46:33 +0400 Subject: [PATCH 04/10] format --- youtokentome/cpp/bpe.h | 50 +++++++++------- youtokentome/cpp/utf8.cpp | 50 +++++++--------- youtokentome/cpp/utf8.h | 14 ++--- youtokentome/cpp/utils.cpp | 22 +++---- youtokentome/cpp/utils.h | 106 ++++++++++++++++----------------- youtokentome/cpp/wordpiece.cpp | 56 +++++++++-------- 6 files changed, 146 insertions(+), 152 deletions(-) diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index 70094f8..65296cf 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -31,8 +31,7 @@ struct BpeConfig { BpeConfig() = default; - BpeConfig(double character_coverage, int n_threads, - const SpecialTokens &special_tokens); + BpeConfig(double character_coverage, int n_threads, const SpecialTokens &special_tokens); }; struct BPEState { @@ -52,8 +51,10 @@ struct EncodingConfig { double dropout_prob; }; -Status train_bpe(const std::string &input_path, const std::string &model_path, - int vocab_size, BpeConfig config); +Status train_bpe(const std::string &input_path, + const std::string &model_path, + int vocab_size, + BpeConfig config); Status learn_bpe_from_string(std::string &text_utf8, int n_tokens, @@ -82,15 +83,19 @@ class BaseEncoder { void fill_from_state(); - Status encode_as_ids( - const std::vector &sentences, std::vector> *ids, bool bos = false, - bool eos = false, bool reverse = false, double dropout_prob=0) const; + Status encode_as_ids(const std::vector &sentences, + std::vector> *ids, + bool bos = false, + bool eos = false, + bool reverse = false, + double dropout_prob = 0) const; - Status encode_as_subwords( - const std::vector &sentences, - std::vector> *subwords, - bool bos = false, - bool eos = false, bool reverse = false, double dropout_prob=0) const; + Status encode_as_subwords(const std::vector &sentences, + std::vector> *subwords, + bool bos = false, + bool eos = false, + bool reverse = false, + double dropout_prob = 0) const; Status id_to_subword(int id, std::string *subword, bool replace_space = false) const; @@ -100,7 +105,9 @@ class BaseEncoder { std::vector *sentences, const std::unordered_set *ignore_ids) const; - Status decode(const std::vector &ids, std::string *sentence, const std::unordered_set *ignore_ids) const; + Status decode(const std::vector &ids, + std::string *sentence, + const std::unordered_set *ignore_ids) const; Status decode(const std::vector &ids, std::vector *sentences, @@ -110,8 +117,12 @@ class BaseEncoder { std::vector vocabulary() const; - Status encode_cli(const std::string &output_type, bool stream, bool bos = false, - bool eos = false, bool reverse = false, double dropout_prob = 0) const; + Status encode_cli(const std::string &output_type, + bool stream, + bool bos = false, + bool eos = false, + bool reverse = false, + double dropout_prob = 0) const; Status decode_cli(const std::unordered_set *ignore_ids) const; @@ -122,11 +133,10 @@ class BaseEncoder { const EncodingConfig &encoding_config, OutputType output_type) const; - Status encode_parallel( 
- const std::vector &sentences, - const EncodingConfig &encoding_config, OutputType output_type, - std::vector *decoder_results - ) const; + Status encode_parallel(const std::vector &sentences, + const EncodingConfig &encoding_config, + OutputType output_type, + std::vector *decoder_results) const; }; } // namespace vkcom diff --git a/youtokentome/cpp/utf8.cpp b/youtokentome/cpp/utf8.cpp index 7334763..02f05a5 100644 --- a/youtokentome/cpp/utf8.cpp +++ b/youtokentome/cpp/utf8.cpp @@ -4,31 +4,29 @@ namespace vkcom { bool is_space(uint32_t ch) { - return (ch < 256 && std::isspace(static_cast(ch))) || (ch == SPACE_TOKEN); + return (ch < 256 && std::isspace(static_cast(ch))) || (ch == SPACE_TOKEN); } bool is_punctuation(uint32_t ch) { - return (ch < 256 && std::ispunct(static_cast(ch))) || ch == 183 || ch == 171 - || ch == 187 || ch == 8249 || ch == 8250 || (8208 <= ch && ch <= 8248); + return (ch < 256 && std::ispunct(static_cast(ch))) || ch == 183 || ch == 171 + || ch == 187 || ch == 8249 || ch == 8250 || (8208 <= ch && ch <= 8248); } bool is_chinese(uint32_t ch) { - if ((ch >= 0x4E00 && ch <= 0x9FFF) || (ch >= 0x3400 && ch <= 0x4DBF) - || (ch >= 0x20000 && ch <= 0x2A6DF) || (ch >= 0x2A700 && ch <= 0x2B73F) - || (ch >= 0x2B740 && ch <= 0x2B81F) || (ch >= 0x2B820 && ch <= 0x2CEAF) - || (ch >= 0xF900 && ch <= 0xFAFF) || (ch >= 0x2F800 && ch <= 0x2FA1F)) { - return true; - } - return false; + if ((ch >= 0x4E00 && ch <= 0x9FFF) || (ch >= 0x3400 && ch <= 0x4DBF) + || (ch >= 0x20000 && ch <= 0x2A6DF) || (ch >= 0x2A700 && ch <= 0x2B73F) + || (ch >= 0x2B740 && ch <= 0x2B81F) || (ch >= 0x2B820 && ch <= 0x2CEAF) + || (ch >= 0xF900 && ch <= 0xFAFF) || (ch >= 0x2F800 && ch <= 0x2FA1F)) { + return true; + } + return false; } bool is_spacing_char(uint32_t ch) { return is_space(ch) || is_punctuation(ch) || is_chinese(ch); } bool check_byte(char x) { return (static_cast(x) & 0xc0u) == 0x80u; } -bool check_codepoint(uint32_t x) { - return (x < 0xd800) || (0xdfff < x && x < 0x110000); -} +bool check_codepoint(uint32_t x) { return (x < 0xd800) || (0xdfff < x && x < 0x110000); } uint64_t utf_length(char ch) { if ((static_cast(ch) & 0x80u) == 0) { @@ -47,7 +45,7 @@ uint64_t utf_length(char ch) { return 0; } -uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { +uint32_t chars_to_utf8(const char *begin, uint64_t size, uint64_t *utf8_len) { uint64_t length = utf_length(begin[0]); if (length == 1) { *utf8_len = 1; @@ -61,8 +59,7 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { *utf8_len = 2; return code_point; } - } else if (size >= 3 && length == 3 && check_byte(begin[1]) && - check_byte(begin[2])) { + } else if (size >= 3 && length == 3 && check_byte(begin[1]) && check_byte(begin[2])) { code_point += (static_cast(begin[0]) & 0x0fu) << 12u; code_point += (static_cast(begin[1]) & 0x3fu) << 6u; code_point += (static_cast(begin[2]) & 0x3fu); @@ -70,8 +67,8 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { *utf8_len = 3; return code_point; } - } else if (size >= 4 && length == 4 && check_byte(begin[1]) && - check_byte(begin[2]) && check_byte(begin[3])) { + } else if (size >= 4 && length == 4 && check_byte(begin[1]) && check_byte(begin[2]) + && check_byte(begin[3])) { code_point += (static_cast(begin[0]) & 0x07u) << 18u; code_point += (static_cast(begin[1]) & 0x3fu) << 12u; code_point += (static_cast(begin[2]) & 0x3fu) << 6u; @@ -87,9 +84,9 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { } bool 
starts_with_space(const char *begin, int64_t size) { - uint64_t len = 0; - uint32_t symbol = chars_to_utf8(begin, size, &len); - return is_space(symbol); + uint64_t len = 0; + uint32_t symbol = chars_to_utf8(begin, size, &len); + return is_space(symbol); } void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { @@ -119,7 +116,7 @@ void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { *(it++) = 0x80u | (x & 0x3fu); } -std::string encode_utf8(const std::vector& text) { +std::string encode_utf8(const std::vector &text) { std::string utf8_text; for (const uint32_t c : text) { utf8_to_chars(c, std::back_inserter(utf8_text)); @@ -127,7 +124,7 @@ std::string encode_utf8(const std::vector& text) { return utf8_text; } -std::vector decode_utf8(const char* begin, const char* end) { +std::vector decode_utf8(const char *begin, const char *end) { std::vector decoded_text; uint64_t utf8_len = 0; bool invalid_input = false; @@ -140,14 +137,13 @@ std::vector decode_utf8(const char* begin, const char* end) { } } if (invalid_input) { - std::cerr << "WARNING Input contains invalid unicode characters." - << std::endl; + std::cerr << "WARNING Input contains invalid unicode characters." << std::endl; } return decoded_text; } -std::vector decode_utf8(const std::string& utf8_text) { +std::vector decode_utf8(const std::string &utf8_text) { return decode_utf8(utf8_text.data(), utf8_text.data() + utf8_text.size()); } -} // namespace vkcom +} // namespace vkcom diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h index 0358c3a..e3a7c72 100644 --- a/youtokentome/cpp/utf8.h +++ b/youtokentome/cpp/utf8.h @@ -27,7 +27,7 @@ bool check_codepoint(uint32_t x); uint64_t utf_length(char ch); -uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len); +uint32_t chars_to_utf8(const char *begin, uint64_t size, uint64_t *utf8_len); bool starts_with_space(const char *begin, int64_t size); @@ -40,7 +40,7 @@ std::vector decode_utf8(const char *begin, const char *end); std::vector decode_utf8(const std::string &utf8_text); struct UTF8Iterator { - UTF8Iterator(char* begin, char* end): begin(begin), end(end) {} + UTF8Iterator(char *begin, char *end) : begin(begin), end(end) {} UTF8Iterator operator++() { if (!state) { @@ -58,20 +58,16 @@ struct UTF8Iterator { return code_point; } - char* get_ptr() { - return begin; - } + char *get_ptr() { return begin; } - uint64_t get_utf8_len() { - return utf8_len; - } + uint64_t get_utf8_len() { return utf8_len; } bool empty() { assert(begin <= end); return begin == end; } -private: + private: char *begin, *end; uint32_t code_point = 0; uint64_t utf8_len = 0; diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index e7de2e1..79fdc94 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -8,21 +8,15 @@ namespace vkcom { Status::Status(int code, std::string message) : code(code), message(std::move(message)) {} -const std::string &Status::error_message() const { - return message; -} -bool Status::ok() const { - return code == 0; -} +const std::string &Status::error_message() const { return message; } + +bool Status::ok() const { return code == 0; } void SpecialTokens::dump(std::ofstream &fout) { - fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id - << std::endl; + fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id << std::endl; } -void SpecialTokens::load(std::ifstream &fin) { - fin >> unk_id >> pad_id >> bos_id >> eos_id; -} +void SpecialTokens::load(std::ifstream &fin) { fin >> unk_id >> 
pad_id >> bos_id >> eos_id; } uint32_t SpecialTokens::max_id() const { int ret = 0; @@ -47,7 +41,7 @@ uint64_t SpecialTokens::n_special_tokens() const { } SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id) - : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {} + : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {} std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) { std::vector sentences; @@ -70,7 +64,7 @@ Status fast_read_file_utf8(const std::string &file_name, std::string *file_conte while (true) { uint64_t cur_size = file_content->size(); file_content->resize(cur_size + buf_size); - int buf_len = fread((void *) (file_content->data() + cur_size), 1, buf_size, fin); + int buf_len = fread((void *)(file_content->data() + cur_size), 1, buf_size, fin); if (buf_len < buf_size) { file_content->resize(file_content->size() - (buf_size - buf_len)); fclose(fin); @@ -79,4 +73,4 @@ Status fast_read_file_utf8(const std::string &file_name, std::string *file_conte } } -} // namespace vkcom +} // namespace vkcom diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index 4ec1a1d..bad231b 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -53,7 +53,7 @@ std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *p Status fast_read_file_utf8(const std::string &file_name, std::string *file_content); -template +template void write_to_stdout(const std::vector &items, bool flush) { for (const auto &item : items) { std::cout << item << " "; @@ -64,7 +64,7 @@ void write_to_stdout(const std::vector &items, bool flush) { } } -template +template void write_to_stdout(const std::vector> &sentences, bool flush) { for (const auto &sentence : sentences) { write_to_stdout(sentence, false); @@ -77,68 +77,68 @@ void write_to_stdout(const std::vector> &sentences, bool flush) { class VectorSegmentBuilder; struct VectorSegment { - private: - friend class VectorSegmentBuilder; - - const uint32_t *begin_; - const uint32_t *end_; - const uint64_t hash_; - - VectorSegment(const uint32_t *begin, const uint32_t *end, uint64_t hash) - : begin_(begin), end_(end), hash_(hash) {} - - public: - bool operator==(const VectorSegment &other) const { - if (other.hash() != hash() || end_ - begin_ != other.end_ - other.begin_) { - return false; - } - for (auto it = begin_, other_it = other.begin_; it != end_; it++, other_it++) { - if (*it != *other_it) { - return false; - } - } - return true; + private: + friend class VectorSegmentBuilder; + + const uint32_t *begin_; + const uint32_t *end_; + const uint64_t hash_; + + VectorSegment(const uint32_t *begin, const uint32_t *end, uint64_t hash) + : begin_(begin), end_(end), hash_(hash) {} + + public: + bool operator==(const VectorSegment &other) const { + if (other.hash() != hash() || end_ - begin_ != other.end_ - other.begin_) { + return false; + } + for (auto it = begin_, other_it = other.begin_; it != end_; it++, other_it++) { + if (*it != *other_it) { + return false; + } } + return true; + } - uint64_t hash() const { return hash_; } + uint64_t hash() const { return hash_; } }; class VectorSegmentBuilder { - private: - constexpr static uint64_t MOD = 2032191299; - constexpr static uint64_t P = 726328703; - - const uint32_t *begin_; - const uint32_t *end_; - std::vector prefix_hash_; - - public: - VectorSegmentBuilder(const std::vector &segment) - : VectorSegmentBuilder(segment.data(), segment.data() + segment.size()) {} - - VectorSegmentBuilder(const uint32_t *begin, const 
uint32_t *end) : begin_(begin), end_(end) { - uint64_t hash = 0; - prefix_hash_.reserve(static_cast(end - begin)); - for (const uint32_t *it = begin_; it != end_; it++) { - hash = (hash * P + *it) % MOD; - prefix_hash_.push_back(hash); - } + private: + constexpr static uint64_t MOD = 2032191299; + constexpr static uint64_t P = 726328703; + + const uint32_t *begin_; + const uint32_t *end_; + std::vector prefix_hash_; + + public: + VectorSegmentBuilder(const std::vector &segment) + : VectorSegmentBuilder(segment.data(), segment.data() + segment.size()) {} + + VectorSegmentBuilder(const uint32_t *begin, const uint32_t *end) : begin_(begin), end_(end) { + uint64_t hash = 0; + prefix_hash_.reserve(static_cast(end - begin)); + for (const uint32_t *it = begin_; it != end_; it++) { + hash = (hash * P + *it) % MOD; + prefix_hash_.push_back(hash); } + } - VectorSegment finish() const { return VectorSegment(begin_, end_, hash()); } + VectorSegment finish() const { return VectorSegment(begin_, end_, hash()); } - size_t size() const { return prefix_hash_.size(); } + size_t size() const { return prefix_hash_.size(); } - bool empty() const { return prefix_hash_.empty(); } + bool empty() const { return prefix_hash_.empty(); } - uint64_t hash() const { return prefix_hash_.empty() ? 0 : prefix_hash_.back(); } + uint64_t hash() const { return prefix_hash_.empty() ? 0 : prefix_hash_.back(); } - void pop_back() noexcept { - if (!prefix_hash_.empty()) { - prefix_hash_.pop_back(); - --end_; - } + void pop_back() noexcept { + if (!prefix_hash_.empty()) { + prefix_hash_.pop_back(); + --end_; } + } }; } // namespace vkcom @@ -147,7 +147,7 @@ namespace std { template <> struct hash { - uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash(); } + uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash(); } }; } // namespace std \ No newline at end of file diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp index 1d05b82..fa61806 100644 --- a/youtokentome/cpp/wordpiece.cpp +++ b/youtokentome/cpp/wordpiece.cpp @@ -15,31 +15,32 @@ namespace { struct WordPieceToken { explicit WordPieceToken(const std::string &encoded_word) - : is_prefix(true), is_special(false), is_malformed(false), word(vkcom::decode_utf8(encoded_word)) { - if (isSuffixVocab(word)) { - is_prefix = false; - word.erase(word.begin(), word.begin() + 2); - } else if (isSpecialToken(word)) { - is_special = true; - } + : is_prefix(true), is_special(false), is_malformed(false), + word(vkcom::decode_utf8(encoded_word)) { + if (isSuffixVocab(word)) { + is_prefix = false; + word.erase(word.begin(), word.begin() + 2); + } else if (isSpecialToken(word)) { + is_special = true; + } - bool all_punctuation = true; - for (uint32_t code_point : word) { - if (code_point == vkcom::INVALID_UNICODE) { - is_malformed = true; + bool all_punctuation = true; + for (uint32_t code_point : word) { + if (code_point == vkcom::INVALID_UNICODE) { + is_malformed = true; + } + if (!vkcom::is_punctuation(code_point) && !vkcom::is_space(code_point)) { + all_punctuation = false; + } } - if (!vkcom::is_punctuation(code_point) && !vkcom::is_space(code_point)) { - all_punctuation = false; + if (word.empty()) { + throw std::runtime_error("Vocab word is empty"); + } + if (is_malformed || (all_punctuation && word.size() > 1)) { + is_malformed = true; + std::cerr << "Vocab word is malformed: " << encoded_word << std::endl; } } - if (word.empty()) { - throw std::runtime_error("Vocab word is empty"); - } - if (is_malformed || 
(all_punctuation && word.size() > 1)) {
-    is_malformed = true;
-    std::cerr << "Vocab word is malformed: " << encoded_word << std::endl;
-  }
-}
 
   bool is_prefix;
   bool is_special;
@@ -162,8 +163,7 @@ std::vector encodeWordPieceImpl(const std::vector &text,
   if (text.size() < 2 * kWorkBatch) {
     token_ids = worker(0, text.size());
   } else {
-    const size_t thread_count
-        = std::min(globalThreadPool().maxThreads(), text.size() / kWorkBatch);
+    const size_t thread_count = std::min(globalThreadPool().maxThreads(), text.size() / kWorkBatch);
     const size_t work_batch = text.size() / thread_count + 1;
     std::vector> per_thread_token_ids(thread_count);
     size_t work_begin = 0;
@@ -172,10 +172,9 @@ std::vector encodeWordPieceImpl(const std::vector &text,
     while (work_end < text.size() && !vkcom::is_space(text[work_end])) {
       ++work_end;
     }
-    globalThreadPool().submit(
-        [thread_id, work_begin, work_end, &per_thread_token_ids, &worker] {
-          per_thread_token_ids[thread_id] = worker(work_begin, work_end);
-        });
+    globalThreadPool().submit([thread_id, work_begin, work_end, &per_thread_token_ids, &worker] {
+      per_thread_token_ids[thread_id] = worker(work_begin, work_end);
+    });
     work_begin = work_end;
   }
 
@@ -199,8 +198,7 @@ std::vector encodeWordPieceImpl(const std::vector &text,
   return token_ids;
 }
 
-std::vector
-encodeWordPiece(const char *text, size_t size, const WordPieceVocabulary &vocab) {
+std::vector encodeWordPiece(const char *text, size_t size, const WordPieceVocabulary &vocab) {
   if (size == 0) {
     return {};
   }
From 49a1dfe67d6cfe4496ffcc7f49b9d107302f35bb Mon Sep 17 00:00:00 2001
From: Gleb Koveshnikov 
Date: Sat, 1 Apr 2023 16:48:40 +0400
Subject: [PATCH 05/10] draft python interface

---
 README.md                            |  50 ++++----
 tests/unit_tests/bpe/stress_test.cpp |   6 +-
 youtokentome/cpp/bpe.cpp             |  15 ++-
 youtokentome/cpp/bpe.h               |   4 +-
 youtokentome/cpp/utils.h             |   5 -
 youtokentome/cpp/wordpiece.cpp       | 170 ++++++++++++++++++-------
 youtokentome/cpp/wordpiece.h         |  17 ++-
 youtokentome/cpp/yttm.pyx            |  34 +++---
 youtokentome/yttm_cli.py             |  16 +--
 9 files changed, 211 insertions(+), 106 deletions(-)

diff --git a/README.md b/README.md
index a9a2f73..32fa74c 100644
--- a/README.md
+++ b/README.md
@@ -219,16 +219,16 @@ Options:
   --help  Show this message and exit.
 
 Commands:
-  bpe     Train BPE model.
-  decode  Decode ids to text.
-  encode  Encode text to ids or subwords.
-  vocab   Print list of learned subwords.
+  bpe-train   Train BPE model.
+  bpe-decode  Decode ids to text.
+  bpe-encode  Encode text to ids or subwords.
+  bpe-vocab   Print list of learned subwords.
 ```
 
-Command `bpe` allows you to train Byte Pair Encoding model based on a text file.
+Command `bpe-train` allows you to train a Byte Pair Encoding model based on a text file.
 
 ```
-$ yttm bpe --help
+$ yttm bpe-train --help
 
-Usage: yttm bpe [OPTIONS]
+Usage: yttm bpe-train [OPTIONS]
 
@@ -247,6 +247,20 @@ Options:
   --help                Show this message and exit.
 ```
 
+Convert ids back to text. Use `stdin` for input and `stdout` for output.
+
+```
+$ yttm bpe-decode --help
+
+Usage: yttm bpe-decode [OPTIONS]
+
+  Decode ids to text.
+
+Options:
+  --model PATH  Path to file with learned model.  [required]
+  --ignore_ids  List of indices to ignore for decoding. Example: --ignore_ids=1,2,3
+  --help        Show this message and exit.
+```
 
 Apply BPE encoding for a corpus of sentences. Use `stdin` for input and `stdout` for output.
 
@@ -256,9 +270,8 @@ By default, encoding works in parallel using `n_threads` threads. Number of thre
 
 With the `--stream` option, `--n_threads` will be ignored and all sentences will be processed one by one. Each sentence will be tokenized and written to the `stdout` before the next sentence is read.
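+
+For instance (an illustrative pipeline; `INPUT_FILE` and `MODEL_FILE` are placeholders), stream mode can be combined with other shell tools to tokenize lines as they arrive:
+
+```bash
+tail -f INPUT_FILE | yttm bpe-encode --stream --model=MODEL_FILE --output_type=id
+```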
-
 ```
-$ yttm encode --help
+$ yttm bpe-encode --help
 
-Usage: yttm encode [OPTIONS]
+Usage: yttm bpe-encode [OPTIONS]
 
@@ -279,7 +292,7 @@ Options:
 
 Print vocabulary. This can be useful for understanding the model.
 
 ```
-$ yttm vocab --help
+$ yttm bpe-vocab --help
 
-Usage: yttm vocab [OPTIONS]
+Usage: yttm bpe-vocab [OPTIONS]
 
@@ -291,26 +304,11 @@ Options:
   --help     Show this message and exit.
 ```
 
-Convert ids back to text. Use `stdin` for input and `stdout` for output.
-
-```
-$ yttm decode --help
-
-Usage: yttm decode [OPTIONS]
-
-  Decode ids to text.
-
-Options:
-  --model PATH  Path to file with learned model.  [required]
-  --ignore_ids  List of indices to ignore for decoding. Example: --ignore_ids=1,2,3
-  --help        Show this message and exit.
-```
-
 ### Examples
 
 TODO: wordpiece
 
 ```bash
-$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000
-$ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA
-```
\ No newline at end of file
+$ yttm bpe-train --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000
+$ yttm bpe-encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA
+```
diff --git a/tests/unit_tests/bpe/stress_test.cpp b/tests/unit_tests/bpe/stress_test.cpp
index db4b2a2..32e8a8c 100644
--- a/tests/unit_tests/bpe/stress_test.cpp
+++ b/tests/unit_tests/bpe/stress_test.cpp
@@ -320,7 +320,7 @@ void manual_test() {
   BpeConfig bpe_config = {1.0, 1, special_tokens_config};
   BPEState model_fast;
-  status = learn_bpe_from_string(trn_data_copy, n_tokens, "remove_it.txt", bpe_config, &model_fast);
+  status = bpe_learn_from_string(trn_data_copy, n_tokens, "remove_it.txt", bpe_config, &model_fast);
   assert(status.ok());
   auto model_slow = learn_bpe_slow(trn_data, n_tokens, "remove_it.txt", bpe_config);
   assert(model_fast.rules == model_slow.rules);
@@ -369,7 +369,7 @@ void parallel_test(int n_iter, int n_threads) {
     auto train_data_copy = train_data;
     BpeConfig bpe_config = {character_coverage, n_threads, {0, 1, 2, 3}};
     BPEState learned_model;
-    status = learn_bpe_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &learned_model);
+    status = bpe_learn_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &learned_model);
     assert(status.ok());
 
     BaseEncoder applyer(learned_model, 20);
@@ -412,7 +412,7 @@ void base_stress(int n_iter) {
     auto train_data_copy = train_data;
     BpeConfig bpe_config = {character_coverage, n_threads, {0, 1, 2, 3}};
     BPEState fast_solution_model;
-    status = learn_bpe_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &fast_solution_model);
+    status = bpe_learn_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &fast_solution_model);
     assert(status.ok());
     auto slow_solution_model = learn_bpe_slow(train_data, vocab_size, "remove_it.txt", bpe_config);
 
diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index 52a27b3..5d625c7 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -22,6 +22,15 @@
 #include "utf8.h"
 #include "utils.h"
 
+namespace {
+
+const std::string UNK_TOKEN = "";
+const std::string PAD_TOKEN = "";
+const std::string BOS_TOKEN = "";
+const std::string EOS_TOKEN = "";
+
+} // namespace
+
 namespace vkcom {
 
 bool BPE_Rule::operator==(const BPE_Rule &other) const {
@@ -853,7 +862,7 @@ uint64_t compute_char_count(flat_hash_map& char_cnt, char* b
   return char_count;
 }
 
-Status learn_bpe_from_string(std::string &text_utf8, int n_tokens,
+Status bpe_learn_from_string(std::string &text_utf8, int n_tokens,
                              const std::string &output_file, BpeConfig
bpe_config, BPEState *bpe_state) {
   assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1);
@@ -1362,7 +1371,7 @@ void print_config(const std::string &input_path, const std::string &model_path,
   std::cerr << std::endl;
 }
 
-Status train_bpe(const std::string &input_path, const std::string &model_path,
+Status bpe_train(const std::string &input_path, const std::string &model_path,
                  int vocab_size, BpeConfig bpe_config) {
   Status status = check_config(bpe_config, vocab_size);
   if (!status.ok()) {
@@ -1377,7 +1386,7 @@ Status train_bpe(const std::string &input_path, const std::string &model_path,
   }
   std::cerr << "learning bpe..." << std::endl;
   BPEState bpe_state;
-  status = learn_bpe_from_string(data, vocab_size, model_path, bpe_config, &bpe_state);
+  status = bpe_learn_from_string(data, vocab_size, model_path, bpe_config, &bpe_state);
   if (!status.ok()) {
     return status;
   }
diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h
index 65296cf..e40ffc0 100644
--- a/youtokentome/cpp/bpe.h
+++ b/youtokentome/cpp/bpe.h
@@ -51,12 +51,12 @@ struct EncodingConfig {
   double dropout_prob;
 };
 
-Status train_bpe(const std::string &input_path,
+Status bpe_train(const std::string &input_path,
                  const std::string &model_path,
                  int vocab_size,
                  BpeConfig config);
 
-Status learn_bpe_from_string(std::string &text_utf8,
+Status bpe_learn_from_string(std::string &text_utf8,
                              int n_tokens,
                              const std::string &output_file,
                              BpeConfig bpe_config,
diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h
index bad231b..8010e74 100644
--- a/youtokentome/cpp/utils.h
+++ b/youtokentome/cpp/utils.h
@@ -6,11 +6,6 @@
 
 namespace vkcom {
 
-const std::string UNK_TOKEN = "";
-const std::string PAD_TOKEN = "";
-const std::string BOS_TOKEN = "";
-const std::string EOS_TOKEN = "";
-
 enum OutputType { ID, SUBWORD };
 
 struct DecodeResult {
diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp
index fa61806..3713b2b 100644
--- a/youtokentome/cpp/wordpiece.cpp
+++ b/youtokentome/cpp/wordpiece.cpp
@@ -13,6 +13,11 @@
 
 namespace {
 
+const std::string UNK_TOKEN = "[UNK]";
+const std::string PAD_TOKEN = "[PAD]";
+const std::string BOS_TOKEN = "[BOS]";
+const std::string EOS_TOKEN = "[EOS]";
+
 struct WordPieceToken {
   explicit WordPieceToken(const std::string &encoded_word)
     : is_prefix(true), is_special(false), is_malformed(false),
@@ -49,35 +54,50 @@ struct WordPieceToken {
 };
 
 struct WordPieceVocabulary {
-  static constexpr int kDefaultUnkTokenId = -1;
-
-  std::vector tokens;
-  int unk_token_id = kDefaultUnkTokenId;
-};
+  explicit WordPieceVocabulary(const std::vector& words) {
+    tokens.reserve(words.size());
+    int token_id = 0;
+    for (const std::string& word : words) {
+      update_special_tokens(word, token_id);
+      WordPieceToken token(word);
+      tokens.push_back(std::move(token));
+      ++token_id;
+    }
+  }
 
-WordPieceVocabulary readVocabFromFile(const std::string &file) {
-  WordPieceVocabulary vocab_utf8;
-  std::ifstream fin(file);
-  std::string word;
-  int token_id = 0;
-  while (std::getline(fin, word)) {
-    if (word == kUnkTokenIdStr) {
-      vocab_utf8.unk_token_id = token_id;
-    }
-    WordPieceToken token(word);
-    vocab_utf8.tokens.push_back(std::move(token));
-    ++token_id;
+  explicit WordPieceVocabulary(const std::string &file) {
+    std::ifstream fin(file);
+    std::string word;
+    int token_id = 0;
+    while (std::getline(fin, word)) {
+      update_special_tokens(word, token_id);
+      WordPieceToken token(word);
+      tokens.push_back(std::move(token));
+      ++token_id;
+    }
   }
-  return vocab_utf8;
-}
-vkcom::ThreadPool &globalThreadPool(size_t n_threads) { - static vkcom::ThreadPool thread_pool(n_threads); - return thread_pool; -} + std::vector tokens; + SpecialTokens special_tokens; + +private: + void update_special_tokens(const std::string& word, int token_id) { + if (word == UNK_TOKEN) { + special_tokens.unk_id = token_id; + } else if (word == PAD_TOKEN) { + special_tokens.pad_id = token_id; + } else if (word == BOS_TOKEN) { + special_tokens.bos_id = token_id; + } else if (word == EOS_TOKEN) { + special_tokens.eos_id = token_id; + } + } +}; -std::vector encodeWordPieceImpl(const std::vector &text, - const WordPieceVocabulary &vocab) { +std::vector encode_word_piece_impl(const std::vector &text, + const WordPieceVocabulary &vocab, + vkcom::ThreadPool& thread_pool) { using WordMap = std::unordered_map; WordMap prefix_to_id; // no ## in word prefix WordMap suffix_to_id; // ## in word prefix @@ -100,7 +120,7 @@ std::vector encodeWordPieceImpl(const std::vector &text, || vkcom::is_spacing_char(text[index - 1]); }; - const auto worker = [&, unk_token_id = vocab.unk_token_id](size_t begin, size_t end) { + const auto worker = [&, unk_token_id = vocab.special_tokens.unk_id](size_t begin, size_t end) { std::vector token_ids; token_ids.reserve((end - begin) / max_len + 1); @@ -163,7 +183,7 @@ std::vector encodeWordPieceImpl(const std::vector &text, if (text.size() < 2 * kWorkBatch) { token_ids = worker(0, text.size()); } else { - const size_t thread_count = std::min(globalThreadPool().maxThreads(), text.size() / kWorkBatch); + const size_t thread_count = std::min(thread_pool.maxThreads(), text.size() / kWorkBatch); const size_t work_batch = text.size() / thread_count + 1; std::vector> per_thread_token_ids(thread_count); size_t work_begin = 0; @@ -172,13 +192,13 @@ std::vector encodeWordPieceImpl(const std::vector &text, while (work_end < text.size() && !vkcom::is_space(text[work_end])) { ++work_end; } - globalThreadPool().submit([thread_id, work_begin, work_end, &per_thread_token_ids, &worker] { + thread_pool.submit([thread_id, work_begin, work_end, &per_thread_token_ids, &worker] { per_thread_token_ids[thread_id] = worker(work_begin, work_end); }); work_begin = work_end; } - globalThreadPool().waitCompletion(); + thread_pool.waitCompletion(); size_t token_count = 0; for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { @@ -198,35 +218,103 @@ std::vector encodeWordPieceImpl(const std::vector &text, return token_ids; } -std::vector encodeWordPiece(const char *text, size_t size, const WordPieceVocabulary &vocab) { +std::vector encode_word_piece(const char *text, size_t size, const WordPieceVocabulary &vocab) { if (size == 0) { return {}; } - const std::vector text_utf8 = utils::parseText(text, size, globalThreadPool()); - return encodeWordPieceImpl(text_utf8, vocab); + vkcom::ThreadPool thread_pool(0); + const std::vector text_utf8 = utils::parseText(text, size, thread_pool); + return encode_word_piece_impl(text_utf8, vocab, thread_pool); } } // namespace namespace vkcom::wordpiece { -/*std::vector encode_wordpiece(const std::string& input_path, const std::string& vocab_path) { - const WordPieceVocabulary vocab_utf8 = readVocabFromFile(vocab_file); - // TODO: use mapped file - std::string text = read_file(input_path); - return encodeWordPiece(text.data(), text.size(), vocab_utf8); +Status encode_as_ids(const std::string &text_path, + const std::string& vocab_path, std::vector *ids) { + const uint64_t batch_limit = 10 * 1024 * 1024; + try { + std::string text; + Status status = 
fast_read_file_utf8(text_path, &text);
+    if (!status.ok()) {
+      return status;
+    }
+    uint64_t processed = 0;
+    // TODO: read the vocabulary from vocab_path instead of stdin
+    std::vector vocab = read_lines_from_stdin(batch_limit, &processed);
+    return encode_as_ids(text, vocab, ids);
+  } catch (const std::exception& ex) {
+    return Status(1, ex.what());
+  } catch (...) {
+    return Status(1, "Unknown error");
+  }
 }
 
-Status encode_wordpiece_cli(const std::string& input_path, const std::string& vocab_path) {
+Status encode_as_ids(const std::string &text,
+                     const std::vector& vocab, std::vector *ids) {
   try {
-    std::vector ids = encode_wordpiece(input_path, vocab_path);
-    write_to_stdout(ids, true);
+    WordPieceVocabulary word_piece_vocab(vocab);
+    *ids = encode_word_piece(text.data(), text.size(), word_piece_vocab);
     return Status();
   } catch (const std::exception& ex) {
     return Status(1, ex.what());
   } catch (...) {
     return Status(1, "Unknown error");
   }
-}*/
+}
+
+Status encode_as_subwords(const std::string &text_path,
+                          const std::string& vocab_path,
+                          std::vector *subwords) {
+  const uint64_t batch_limit = 10 * 1024 * 1024;
+  try {
+    std::string text;
+    Status status = fast_read_file_utf8(text_path, &text);
+    if (!status.ok()) {
+      return status;
+    }
+    uint64_t processed = 0;
+    // TODO: read the vocabulary from vocab_path instead of stdin
+    std::vector vocab = read_lines_from_stdin(batch_limit, &processed);
+    return encode_as_subwords(text, vocab, subwords);
+  } catch (const std::exception& ex) {
+    return Status(1, ex.what());
+  } catch (...) {
+    return Status(1, "Unknown error");
+  }
+}
+
+Status encode_as_subwords(const std::string &text,
+                          const std::vector& vocab,
+                          std::vector *subwords) {
+  try {
+    WordPieceVocabulary word_piece_vocab(vocab);
+    std::vector ids = encode_word_piece(text.data(), text.size(), word_piece_vocab);
+    for (int id : ids) {
+      subwords->push_back(vocab[id]);
+    }
+    return Status();
+  } catch (const std::exception& ex) {
+    return Status(1, ex.what());
+  } catch (...) {
+    return Status(1, "Unknown error");
+  }
+}
+
+Status decode(const std::vector& ids,
+              const std::vector& vocab,
+              std::vector *subwords,
+              const std::unordered_set *ignore_ids) {
+  try {
+    for (int id : ids) {
+      if (!ignore_ids || ignore_ids->count(id) == 0) {
+        subwords->push_back(vocab[id]);
+      }
+    }
+    return Status();
+  } catch (const std::exception& ex) {
+    return Status(1, ex.what());
+  } catch (...)
{ + return Status(1, "Unknown error"); + } +} } // namespace vkcom::wordpiece \ No newline at end of file diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h index 98375d5..92bc1b8 100644 --- a/youtokentome/cpp/wordpiece.h +++ b/youtokentome/cpp/wordpiece.h @@ -2,16 +2,27 @@ #include #include +#include namespace vkcom::wordpiece { -/*Status encode_as_ids(const std::string &text, +Status encode_as_ids(const std::string &text_path, + const std::string& vocab_path, std::vector *ids); + +Status encode_as_ids(const std::string &text, const std::vector& vocab, std::vector *ids); +Status encode_as_subwords(const std::string &text_path, + const std::string& vocab_path, + std::vector *subwords); + Status encode_as_subwords(const std::string &text, const std::vector& vocab, - std::vector> *subwords); + std::vector *subwords); -Status encode_wordpiece_cli(const std::string& vocab_path);*/ +Status decode(const std::vector& ids, + const std::vector& vocab, + std::vector *subwords, + const std::unordered_set *ignore_ids); } // namespace vkcom::wordpiece \ No newline at end of file diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index 3bddb96..a39b241 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -2,7 +2,6 @@ from libcpp.vector cimport vector from libcpp.unordered_set cimport unordered_set from libcpp.string cimport string from libcpp cimport bool -import os from pathlib import Path from typing import Collection @@ -25,31 +24,29 @@ cdef extern from "utils.h" namespace "vkcom": cdef extern from "bpe.h" namespace "vkcom": - Status train_bpe(const string &source_path, const string& model_path, int vocab_size, const BpeConfig& bpe_config) + Status bpe_train(const string &source_path, const string &model_path, int vocab_size, const BpeConfig &bpe_config) cdef extern from "bpe.h" namespace "vkcom": cdef cppclass BaseEncoder: - BaseEncoder(const string& model_path, int n_threads, Status* status) - - Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse, double dropout_prob) const - Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse, double dropout_prob) const + BaseEncoder(const string &model_path, int n_threads, Status *status) + Status encode_as_ids(const vector[string] &sentences, vector[vector[int]] *ids, bool bos, bool eos, bool reverse, double dropout_prob) const + Status encode_as_subwords(const vector[string] &sentences, vector[vector[string]] *subwords, bool bos, bool eos, bool reverse, double dropout_prob) const Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse, double dropout_prob) const - Status decode_cli(const unordered_set[int]* ignore_ids) const - - void vocab_cli(bool verbose) const - - Status id_to_subword(int id, string* subword) const + Status decode(const vector[vector[int]] &ids, vector[string] *output, const unordered_set[int] *ignore_ids) const + Status decode_cli(const unordered_set[int] *ignore_ids) const + Status id_to_subword(int id, string *subword) const int subword_to_id(const string &subword) const - Status decode(const vector[vector[int]]& ids, vector[string]* output, const unordered_set[int]* ignore_ids) const + int vocab_size() const vector[string] vocabulary() const + void vocab_cli(bool verbose) const cdef class BPE: - cdef BaseEncoder* encoder + cdef BaseEncoder *encoder def __dealloc__(self): del self.encoder @@ -79,7 +76,7 @@ cdef class 
BPE: bpe_config.special_tokens.bos_id = bos_id bpe_config.special_tokens.eos_id = eos_id - cdef Status status = train_bpe(data.encode(), model.encode(), vocab_size, bpe_config) + cdef Status status = bpe_train(data.encode(), model.encode(), vocab_size, bpe_config) if status.code != 0: raise ValueError(status.message.decode()) @@ -133,7 +130,6 @@ cdef class BPE: return subword.decode() def decode(self, ids, ignore_ids): - if not isinstance(ids, list): raise TypeError( "{} is not a list instance".format(type(ids)) @@ -179,3 +175,11 @@ cdef class BPE: def vocab_cli(self, verbose): self.encoder.vocab_cli(verbose) +cdef extern from "wordpiece.h" namespace "vkcom::wordpiece": + Status encode_as_ids(const string &text_path, const string &vocab_path, vector[int] *ids) + Status encode_as_ids(const string &text, const vector[string] &vocab, vector[int] *ids) + + Status encode_as_subwords(const string &text_path, const string &vocab_path, vector[string] *subwords) + Status encode_as_subwords(const string &text, const vector[string] &vocab, vector[string] *subwords) + + Status decode(const vector[int] &ids, const vector[string] &vocab, vector[string] *subwords, const unordered_set[int] *ignore_ids) diff --git a/youtokentome/yttm_cli.py b/youtokentome/yttm_cli.py index 7e66879..4ae8c2e 100644 --- a/youtokentome/yttm_cli.py +++ b/youtokentome/yttm_cli.py @@ -57,7 +57,7 @@ def main(): default=3, show_default=True, ) -def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id): +def bpe_train(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id): """Train BPE model.""" yttmc.BPE.train( data=data, @@ -105,7 +105,7 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eo show_default=True, help="BPE-dropout probability (the probability of a merge being dropped)", ) -def encode(model, output_type, n_threads, bos, eos, reverse, stream, dropout_prob): +def bpe_encode(model, output_type, n_threads, bos, eos, reverse, stream, dropout_prob): """Encode text to ids or subwords.""" if n_threads < -1 or n_threads == 0: raise ValueError( @@ -143,7 +143,7 @@ def validate_ignore_ids(ctx, param, value): required=False, help="List of indices to ignore for decoding. 
Example: --ignore_ids=1,2,3",
 )
-def decode(model, ignore_ids):
+def bpe_decode(model, ignore_ids):
     """Decode ids to text."""
     bpe = yttmc.BPE(model)
     bpe.decode_cli(ignore_ids)
@@ -157,13 +157,13 @@ def decode(model, ignore_ids):
     help="Path to file with learned model.",
 )
 @click.option("--verbose", is_flag=True, help="Add merging rules.")
-def vocab(model, verbose):
+def bpe_vocab(model, verbose):
     """Print list of learned subwords."""
     bpe = yttmc.BPE(model)
     bpe.vocab_cli(verbose)
 
 
-main.add_command(bpe)
-main.add_command(encode)
-main.add_command(decode)
-main.add_command(vocab)
+main.add_command(bpe_train)
+main.add_command(bpe_encode)
+main.add_command(bpe_decode)
+main.add_command(bpe_vocab)
From bb0e275f504a73b1dcd8c2f2c3c3bfcfa1f582ec Mon Sep 17 00:00:00 2001
From: Gleb Koveshnikov 
Date: Sun, 2 Apr 2023 13:27:42 +0400
Subject: [PATCH 06/10] better python interface

---
 README.md                                 |  9 +++-
 tests/speed_test/Dockerfile               |  4 +-
 tests/unit_tests/wordpiece/test_manual.py | 53 ++++++++++++++++++++++-
 youtokentome/cpp/yttm.pyx                 | 32 ++++++++++++++
 4 files changed, 92 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 32fa74c..74335a0 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ For example, the phrase ```Blazingly fast tokenization!``` can be tokenized into
 
 Algorithm properties:
 * Currently supports tokenizer only, but not training
-* Time complexity is `O(NM^2)`, where `N` is the length of tokenized data and `M` is the maximum word length in the vocabulary
+* Time complexity is `O(Nm^2)`, where `N` is the length of tokenized data and `m` is the maximum word length in the vocabulary
 
 ## Installation
 
@@ -201,8 +201,13 @@ Convert each id to subword and concatenate with space symbol.
 TODO
 
 ### Methods
+Class `youtokentome.WordPiece` has the following methods:
 
-TODO
+#### constructor
+
+#### encode
+
+#### decode
 
 ## Command line interface
 
diff --git a/tests/speed_test/Dockerfile b/tests/speed_test/Dockerfile
index 24b385d..9631aec 100644
--- a/tests/speed_test/Dockerfile
+++ b/tests/speed_test/Dockerfile
@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     bzip2 \
     perl && \
     pip3 install -r requirements.txt && \
-    pip3 install youtokentome 
+    pip3 install youtokentome
 
 WORKDIR /repos
 
@@ -38,4 +38,4 @@ COPY ./wordpiece.py ./wordpiece.py
 
 # use comma to separate langs, e.g.: "--langs", "en", "ru", "zh", "ja"
 CMD ["python", "bpe.py", "--langs", "ru", "--corpus_size", "10", "--vocab_size", "30000"]
-CMD ["python", "bpe.py", "--langs", "ru", "--corpus_size", "10", "--vocab", "bert-base-cased.txt"]
\ No newline at end of file
+CMD ["python", "wordpiece.py", "--langs", "ru", "--corpus_size", "10", "--vocab", "bert-base-cased.txt"]
diff --git a/tests/unit_tests/wordpiece/test_manual.py b/tests/unit_tests/wordpiece/test_manual.py
index b49d2ab..3e28860 100644
--- a/tests/unit_tests/wordpiece/test_manual.py
+++ b/tests/unit_tests/wordpiece/test_manual.py
@@ -1,4 +1,53 @@
 # -*- coding: utf-8 -*-
-import os
+import youtokentome as yttm
 
-import youtokentome as yttm
\ No newline at end of file
+
+def check(text, vocab, output_type=yttm.OutputType.ID):
+    encoder = yttm.WordPiece(vocab)
+    return encoder.encode(text, output_type=output_type)
+
+
+def test_russian():
+    ids = check("привет мир", ["привет", "мир"])
+    assert ids == [0, 1]
+
+    ids = check("привет мир", ["при", "##вет", "мир"])
+    assert ids == [0, 1, 2]
+
+    ids = check("токенизация это круто", ["ток", "крут", "это", "##за", "##ция", "ция"])
+    assert ids == [-1, 2, -1]
+
+    ids = check("токенизация это круто", 
["ток", "крут", "это", "##за", "##ени", "##о", "##ция", "ция"]) + assert ids == [0, 4, 3, 6, 2, 1, 5] + + +def test_english(): + ids = check("self-made", ["self", "made", "-", "##-", "##made"]) + assert ids == [0, 2, 1] + + ids = check("self, made", ["self", "made", ",", "##,", "##made"]) + assert ids == [0, 2, 1] + + ids = check("self , made", ["self", "made", ",", "##,", "##made"]) + assert ids == [0, 2, 1] + + +def test_japanese(): + pass + + +def test_misc(): + ids = check("abcdef", ["a", "##bcdef", "ab", "##c", "##d", "##e", "##f"]) + assert ids == [2, 3, 4, 5, 6] + + ids = check("abcdef abc abcd", ["abcd", "def", "abc"]) + assert ids == [-1, 2, 0] + + ids = check("abc", ["a", "abd"]) + assert ids == [-1] + + ids = check("abc a abc abd", ["a", "abd"]) + assert ids == [-1, 0, -1, 1] + + ids = check("abcdef", ["bcde", "ac", "def", "bc", "bcdef", "##a", "##b", "##c", "##d"]) + assert ids == [-1] diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index a39b241..a7964d0 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -183,3 +183,35 @@ cdef extern from "wordpiece.h" namespace "vkcom::wordpiece": Status encode_as_subwords(const string &text, const vector[string] &vocab, vector[string] *subwords) Status decode(const vector[int] &ids, const vector[string] &vocab, vector[string] *subwords, const unordered_set[int] *ignore_ids) + +cdef class WordPiece: + def __init__(self, vocab, n_threads=0): + self.vocab = vocab + self.n_threads = n_threads + + def encode(self, text, output_type): + cdef Status status + if output_type == 'id': + cdef vector[int] ids + status = encode_as_ids(text, self.vocab, &ids) + if status.code != 0: + raise ValueError(status.message.decode()) + return ids + elif output_type == 'subword': + cdef vector[string] subwords + status = encode_as_subwords(text, self.vocab, &subwords) + if status.code != 0: + raise ValueError(status.message.decode()) + return subwords + else: + raise ValueError('output_type must be equal to "id" or "subword"') + + def decode(self, ids, ignore_ids) + if ignore_ids is None: + ignore_ids = set() + cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids) + cdef vector[string] subwords + cdef Status status = decode(ids, self.vocab, &subwords, &c_ignore_ids) + if status.code != 0: + raise ValueError(status.message.decode()) + return subwords From 506fa94304b70afedc9506b34bf9736cb35395e9 Mon Sep 17 00:00:00 2001 From: Gleb Koveshnikov Date: Mon, 3 Apr 2023 04:09:24 +0800 Subject: [PATCH 07/10] binding works --- MANIFEST.in | 3 +- requirements.txt | 15 ++- setup.py | 5 +- tests/unit_tests/bpe/stress_test.cpp | 6 +- tests/unit_tests/bpe/test_cli.py | 32 +++--- tests/unit_tests/bpe/test_stress.py | 2 +- tests/unit_tests/bpe/utils_for_testing.py | 4 +- tests/unit_tests/wordpiece/test_manual.py | 13 ++- youtokentome/__init__.py | 1 + youtokentome/cpp/bpe.cpp | 14 +-- youtokentome/cpp/bpe.h | 2 +- youtokentome/cpp/utf8.cpp | 2 + youtokentome/cpp/utils.cpp | 4 +- youtokentome/cpp/utils.h | 45 ++++---- youtokentome/cpp/wordpiece.cpp | 132 +++++++++++++++------- youtokentome/cpp/wordpiece.h | 13 +-- youtokentome/cpp/yttm.pyx | 28 +++-- youtokentome/youtokentome.py | 37 ++++++ 18 files changed, 234 insertions(+), 124 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 4ade57c..27234f2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,8 +4,9 @@ include youtokentome/cpp/utf8.h include youtokentome/cpp/wordpiece.h include youtokentome/cpp/yttm.pyx include 
youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h +include youtokentome/cpp/third_party/flat_hash_map/LICENSE include youtokentome/cpp/third_party/thread_pool/thread_pool.h -include youtokentome/cpp/third_party/LICENSE +include youtokentome/cpp/third_party/thread_pool/LICENSE include LICENSE include README.md include requirements.txt diff --git a/requirements.txt b/requirements.txt index 16d2c18..32e09ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,10 @@ -setuptools>=32.0.0 -Click>=7.0 -pytest==4.3.1 -tabulate==0.8.5 -Cython==0.29.14 \ No newline at end of file +atomicwrites==1.4.1 +attrs==22.2.0 +click==8.1.3 +Cython==0.29.34 +more-itertools==9.1.0 +pluggy==1.0.0 +py==1.11.0 +pytest==7.2.1 +six==1.16.0 +tabulate==0.9.0 diff --git a/setup.py b/setup.py index b603d23..60abd1c 100644 --- a/setup.py +++ b/setup.py @@ -12,8 +12,9 @@ "youtokentome/cpp/bpe.cpp", "youtokentome/cpp/utils.cpp", "youtokentome/cpp/utf8.cpp", + "youtokentome/cpp/wordpiece.cpp" ], - extra_compile_args=["-std=c++17", "-pthread", "-O3"], + extra_compile_args=["-std=c++11", "-pthread", "-O3"], language="c++", ) ] @@ -35,7 +36,7 @@ python_requires=">=3.5.0", install_requires=["Click>=7.0"], entry_points={"console_scripts": ["yttm = youtokentome.yttm_cli:main"]}, - author="Ivan Belonogov", + author="VKCOM", license="MIT", classifiers=[ "License :: OSI Approved :: MIT License", diff --git a/tests/unit_tests/bpe/stress_test.cpp b/tests/unit_tests/bpe/stress_test.cpp index 32e8a8c..ce8d850 100644 --- a/tests/unit_tests/bpe/stress_test.cpp +++ b/tests/unit_tests/bpe/stress_test.cpp @@ -6,9 +6,9 @@ #include #include -#include "../../youtokentome/cpp/utils.h" -#include "../../youtokentome/cpp/bpe.h" -#include "../../youtokentome/cpp/utf8.h" +#include "../../../youtokentome/cpp/utils.h" +#include "../../../youtokentome/cpp/bpe.h" +#include "../../../youtokentome/cpp/utf8.h" namespace vkcom { diff --git a/tests/unit_tests/bpe/test_cli.py b/tests/unit_tests/bpe/test_cli.py index c9cadee..a8b7f89 100644 --- a/tests/unit_tests/bpe/test_cli.py +++ b/tests/unit_tests/bpe/test_cli.py @@ -18,7 +18,7 @@ def test_bos_eos_reverse(): generate_artifacts() cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=subword", "--n_threads=1", @@ -29,7 +29,7 @@ def test_bos_eos_reverse(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=subword", "--n_threads=1", @@ -41,7 +41,7 @@ def test_bos_eos_reverse(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id", "--n_threads=1", @@ -52,7 +52,7 @@ def test_bos_eos_reverse(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id", "--n_threads=1", @@ -67,11 +67,11 @@ def test_bos_eos_reverse(): def test_interactive_mode(): generate_artifacts() print("interactive helper running id ...") - cmd = f"python interactor.py | yttm encode --stream --model={BASE_MODEL_FILE} --output_type=id > log.txt" + cmd = f"python interactor.py | yttm bpe-encode --stream --model={BASE_MODEL_FILE} --output_type=id > log.txt" assert os.system(cmd) == 0 print("interactive helper running subword ...") - cmd = f"python interactor.py | yttm encode --stream --model={BASE_MODEL_FILE} --output_type=subword > log.txt" + cmd = f"python interactor.py | yttm bpe-encode --stream --model={BASE_MODEL_FILE} --output_type=subword > log.txt" assert os.system(cmd) == 0 os.remove("log.txt") @@ -80,7 +80,7 @@ def test_multithreading(): 
generate_artifacts() cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=subword", "--n_threads=10", @@ -92,7 +92,7 @@ def test_renaming(): generate_artifacts() cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={RENAME_ID_MODEL_FILE}", "--output_type=id", "--bos", @@ -103,7 +103,7 @@ def test_renaming(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={RENAME_ID_MODEL_FILE}", "--output_type=id", "--eos", @@ -122,7 +122,7 @@ def test_renaming_unknown(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={RENAME_ID_MODEL_FILE}", "--output_type=id", "--reverse", @@ -143,8 +143,8 @@ def test_renaming_unknown(): def test_vocab(): generate_artifacts() - run(["yttm", "vocab", f"--model={BASE_MODEL_FILE}"], check=True) - run(["yttm", "vocab", f"--model={BASE_MODEL_FILE}", "--verbose"], check=True) + run(["yttm", "bpe-vocab", f"--model={BASE_MODEL_FILE}"], check=True) + run(["yttm", "bpe-vocab", f"--model={BASE_MODEL_FILE}", "--verbose"], check=True) def test_decode(): @@ -153,7 +153,7 @@ def test_decode(): with open("decode_text_in.txt", "w") as fout: fout.write(text_in) - cmd_args = ["yttm", "encode", f"--model={BASE_MODEL_FILE}", "--output_type=id"] + cmd_args = ["yttm", "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id"] run( cmd_args, stdin=open("decode_text_in.txt", "r"), @@ -161,7 +161,7 @@ def test_decode(): check=True, ) - cmd_args = ["yttm", "decode", f"--model={BASE_MODEL_FILE}"] + cmd_args = ["yttm", "bpe-decode", f"--model={BASE_MODEL_FILE}"] run( cmd_args, stdin=open("decode_id.txt", "r"), @@ -176,7 +176,7 @@ def test_decode(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id", "--bos", @@ -191,7 +191,7 @@ def test_decode(): cmd_args = [ "yttm", - "decode", + "bpe-decode", f"--model={BASE_MODEL_FILE}", f"--ignore_ids={BOS_ID},{EOS_ID}", ] diff --git a/tests/unit_tests/bpe/test_stress.py b/tests/unit_tests/bpe/test_stress.py index 98a576e..92ca8b7 100644 --- a/tests/unit_tests/bpe/test_stress.py +++ b/tests/unit_tests/bpe/test_stress.py @@ -9,7 +9,7 @@ def compile_test(): if tests_compiled: return build_files = ["bpe.cpp", "utils.cpp", "utf8.cpp"] - files = ["../../youtokentome/cpp/" + file_name for file_name in build_files] + files = ["../../../youtokentome/cpp/" + file_name for file_name in build_files] files.append("stress_test.cpp") print("compiling stress test ...") diff --git a/tests/unit_tests/bpe/utils_for_testing.py b/tests/unit_tests/bpe/utils_for_testing.py index 42dba38..eceaa96 100644 --- a/tests/unit_tests/bpe/utils_for_testing.py +++ b/tests/unit_tests/bpe/utils_for_testing.py @@ -37,7 +37,7 @@ def generate_artifacts(): cmd_args = [ "yttm", - "bpe", + "bpe-train", f"--data={TRAIN_FILE}", f"--model={BASE_MODEL_FILE}", "--vocab_size=16000", @@ -49,7 +49,7 @@ def generate_artifacts(): run(cmd_args, check=True) cmd_args = [ "yttm", - "bpe", + "bpe-train", f"--data={TRAIN_FILE}", f"--model={RENAME_ID_MODEL_FILE}", "--vocab_size=16000", diff --git a/tests/unit_tests/wordpiece/test_manual.py b/tests/unit_tests/wordpiece/test_manual.py index 3e28860..fe1c77c 100644 --- a/tests/unit_tests/wordpiece/test_manual.py +++ b/tests/unit_tests/wordpiece/test_manual.py @@ -3,8 +3,17 @@ def check(text, vocab, output_type=yttm.OutputType.ID): - encoder = yttm.WordPiece(vocab) - return encoder.encode(text, output_type=output_type) + TEXT_FILE = "text_file.txt" + VOCAB_FILE = "vocab_file.txt" + with open(TEXT_FILE, 'w') as f: + f.write(text) + with open(VOCAB_FILE, 
'w') as f: + for word in vocab: + f.write(word) + f.write('\n') + + encoder = yttm.WordPiece(VOCAB_FILE) + return encoder.encode(TEXT_FILE, output_type=output_type) def test_russian(): diff --git a/youtokentome/__init__.py b/youtokentome/__init__.py index a0d7baa..ab836bc 100644 --- a/youtokentome/__init__.py +++ b/youtokentome/__init__.py @@ -1,2 +1,3 @@ from .youtokentome import BPE from .youtokentome import OutputType +from .youtokentome import WordPiece diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index 5d625c7..1bc07de 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -390,10 +390,10 @@ struct WordCount { }; -flat_hash_map compute_word_count( +flat_hash_map compute_word_count( char* sbegin, char* send, const flat_hash_map &char2id) { - flat_hash_map hash2wordcnt; + flat_hash_map hash2wordcnt; std::vector word; UTF8Iterator utf8_iter(sbegin, send); @@ -405,8 +405,8 @@ flat_hash_map compute_word_count( char* begin_of_word = utf8_iter.get_ptr(); for (; !utf8_iter.empty() && !is_space(*utf8_iter); ++utf8_iter); char* end_of_word = utf8_iter.get_ptr(); - VectorSegmentBuilder word_hash_builder(begin_of_word, end_of_word); - VectorSegment word_hash = word_hash_builder.finish(); + VectorSegmentBuilder word_hash_builder(begin_of_word, end_of_word); + BpeVectorSegment word_hash = word_hash_builder.finish(); auto it = hash2wordcnt.find(word_hash); if (it == hash2wordcnt.end()) { word.clear(); @@ -889,7 +889,7 @@ Status bpe_learn_from_string(std::string &text_utf8, int n_tokens, flat_hash_set removed_chars; flat_hash_map char2id; - std::vector> hash2wordcnt(n_threads); + std::vector> hash2wordcnt(n_threads); int error_flag = 0; flat_hash_map> recipe; @@ -1047,7 +1047,7 @@ Status bpe_learn_from_string(std::string &text_utf8, int n_tokens, word_cnt_global.resize(hash2wordcnt[0].size()); std::transform( hash2wordcnt[0].begin(), hash2wordcnt[0].end(), word_cnt_global.begin(), - [](const std::pair &x) { return x.second; }); + [](const std::pair &x) { return x.second; }); hash2wordcnt.shrink_to_fit(); text_utf8.shrink_to_fit(); @@ -1986,7 +1986,7 @@ Status BaseEncoder::encode_cli(const std::string &output_type_str, bool stream, int chars_remove = 0; do { processed = 0; - auto sentences = read_lines_from_stdin(batch_limit, &processed); + auto sentences = read_lines(std::cin, batch_limit, &processed); if (output_type == SUBWORD) { std::vector> subwords; Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse, dropout_prob); diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index e40ffc0..901207c 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -4,7 +4,7 @@ #include #include -#include "third_party/flat_hash_map/flat_hash_map/flat_hash_map.h" +#include "third_party/flat_hash_map/flat_hash_map.h" #include "utils.h" diff --git a/youtokentome/cpp/utf8.cpp b/youtokentome/cpp/utf8.cpp index 02f05a5..8b2b2ca 100644 --- a/youtokentome/cpp/utf8.cpp +++ b/youtokentome/cpp/utf8.cpp @@ -26,6 +26,8 @@ bool is_spacing_char(uint32_t ch) { return is_space(ch) || is_punctuation(ch) || bool check_byte(char x) { return (static_cast(x) & 0xc0u) == 0x80u; } +bool check_symbol_start(char x) { return !check_byte(x); }; + bool check_codepoint(uint32_t x) { return (x < 0xd800) || (0xdfff < x && x < 0x110000); } uint64_t utf_length(char ch) { diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 79fdc94..2bf42df 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -43,10 +43,10 @@ uint64_t 
SpecialTokens::n_special_tokens() const { SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id) : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {} -std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) { +std::vector read_lines(std::istream& stream, uint64_t batch_limit, uint64_t *processed) { std::vector sentences; std::string s; - while (*processed < batch_limit && std::getline(std::cin, s)) { + while (*processed < batch_limit && std::getline(stream, s)) { *processed += s.size(); sentences.push_back(std::move(s)); } diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index 8010e74..1b38d5d 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -44,7 +45,7 @@ struct SpecialTokens { uint64_t n_special_tokens() const; }; -std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed); +std::vector read_lines(std::istream& stream, uint64_t batch_limit, uint64_t *processed); Status fast_read_file_utf8(const std::string &file_name, std::string *file_content); @@ -69,20 +70,17 @@ void write_to_stdout(const std::vector> &sentences, bool flush) { } } -class VectorSegmentBuilder; - +template struct VectorSegment { private: - friend class VectorSegmentBuilder; + const T *begin_; + const T *end_; + uint64_t hash_; - const uint32_t *begin_; - const uint32_t *end_; - const uint64_t hash_; - - VectorSegment(const uint32_t *begin, const uint32_t *end, uint64_t hash) + public: + VectorSegment(const T *begin, const T *end, uint64_t hash) : begin_(begin), end_(end), hash_(hash) {} - public: bool operator==(const VectorSegment &other) const { if (other.hash() != hash() || end_ - begin_ != other.end_ - other.begin_) { return false; @@ -98,29 +96,31 @@ struct VectorSegment { uint64_t hash() const { return hash_; } }; +template class VectorSegmentBuilder { private: constexpr static uint64_t MOD = 2032191299; constexpr static uint64_t P = 726328703; - const uint32_t *begin_; - const uint32_t *end_; + const T *begin_; + const T *end_; std::vector prefix_hash_; public: - VectorSegmentBuilder(const std::vector &segment) + explicit VectorSegmentBuilder(const std::vector &segment) : VectorSegmentBuilder(segment.data(), segment.data() + segment.size()) {} - VectorSegmentBuilder(const uint32_t *begin, const uint32_t *end) : begin_(begin), end_(end) { + VectorSegmentBuilder(const T *begin, const T *end) : begin_(begin), end_(end) { + using HashT = typename std::make_unsigned::type; uint64_t hash = 0; prefix_hash_.reserve(static_cast(end - begin)); - for (const uint32_t *it = begin_; it != end_; it++) { - hash = (hash * P + *it) % MOD; + for (const T *it = begin_; it != end_; it++) { + hash = (hash * P + static_cast(*it)) % MOD; prefix_hash_.push_back(hash); } } - VectorSegment finish() const { return VectorSegment(begin_, end_, hash()); } + VectorSegment finish() const { return VectorSegment(begin_, end_, hash()); } size_t size() const { return prefix_hash_.size(); } @@ -136,13 +136,16 @@ class VectorSegmentBuilder { } }; +using BpeVectorSegment = VectorSegment; +using WordPieceVectorSegment = VectorSegment; + } // namespace vkcom namespace std { -template <> -struct hash { - uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash(); } +template +struct hash> { + uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash(); } }; -} // namespace std \ No newline at end of file +} // namespace std diff --git 
a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp index 3713b2b..8551d5f 100644 --- a/youtokentome/cpp/wordpiece.cpp +++ b/youtokentome/cpp/wordpiece.cpp @@ -3,12 +3,13 @@ #include #include #include +#include #include #include #include #include "third_party/flat_hash_map/flat_hash_map.h" -#include "third_party/thread_pool/thread_pool.hpp" +#include "third_party/thread_pool/thread_pool.h" #include "utf8.h" namespace { @@ -18,6 +19,16 @@ const std::string PAD_TOKEN = "[PAD]"; const std::string BOS_TOKEN = "[BOS]"; const std::string EOS_TOKEN = "[EOS]"; +bool isSuffixVocab(const std::vector &word) { + static const uint32_t kSharp = static_cast('#'); + return word.size() >= 2 && word[0] == kSharp && word[1] == kSharp; +} + +bool isSpecialToken(const std::vector &word) { + return word.size() > 2 && word[0] == static_cast('[') + && word.back() == static_cast(']'); +} + struct WordPieceToken { explicit WordPieceToken(const std::string &encoded_word) : is_prefix(true), is_special(false), is_malformed(false), @@ -66,7 +77,6 @@ struct WordPieceVocabulary { } explicit WordPieceVocabulary(const std::string &file) { - WordPieceVocabulary vocab_utf8; std::ifstream fin(file); std::string word; int token_id = 0; @@ -79,7 +89,7 @@ struct WordPieceVocabulary { } std::vector tokens; - SpecialTokens special_tokens; + vkcom::SpecialTokens special_tokens; private: void update_special_tokens(const std::string& word, int token_id) { @@ -95,10 +105,54 @@ struct WordPieceVocabulary { } }; +std::vector parseText(const char *text, size_t size, vkcom::ThreadPool &thread_pool) { + static const size_t kWorkBatch = 5000000; + + if (size < 2 * kWorkBatch) { + return vkcom::decode_utf8(text, text + size); + } else { + const size_t thread_count = std::min(thread_pool.maxThreads(), size / kWorkBatch); + const size_t work_batch = size / thread_count + 1; + std::vector> per_thread_text_utf8(thread_count); + size_t work_start = 0; + for (size_t thread_id = 0; thread_id < thread_count && work_start < size; thread_id++) { + size_t work_end = std::min(size, work_start + work_batch); + while (work_end < size && !vkcom::check_symbol_start(text[work_end])) { + ++work_end; + } + thread_pool.submit([thread_id, work_start, work_end, text, &per_thread_text_utf8] { + const char *begin = text + work_start; + const size_t len = work_end - work_start; + per_thread_text_utf8[thread_id] = vkcom::decode_utf8(begin, begin + len); + }); + work_start = work_end; + } + + thread_pool.waitCompletion(); + size_t text_utf8_size = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + text_utf8_size += per_thread_text_utf8[thread_id].size(); + } + std::vector text_utf8(text_utf8_size); + text_utf8.resize(text_utf8_size); + work_start = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector &segment = per_thread_text_utf8[thread_id]; + if (!segment.empty()) { + std::memcpy(text_utf8.data() + work_start, + segment.data(), + segment.size() * sizeof(uint32_t)); + work_start += segment.size(); + } + } + return text_utf8; + } +} + std::vector encode_word_piece_impl(const std::vector &text, const WordPieceVocabulary &vocab, vkcom::ThreadPool& thread_pool) { - using WordMap = std::unordered_map; + using WordMap = std::unordered_map; WordMap prefix_to_id; // no ## in word prefix WordMap suffix_to_id; // ## in word prefix @@ -109,7 +163,7 @@ std::vector encode_word_piece_impl(const std::vector &text, continue; } max_len = std::max(max_len, token.word.size()); - vkcom::VectorSegmentBuilder 
segment(token.word); + vkcom::VectorSegmentBuilder segment(token.word); WordMap *word_to_id = token.is_prefix ? &prefix_to_id : &suffix_to_id; (*word_to_id)[segment.finish()] = static_cast(i); } @@ -119,8 +173,9 @@ std::vector encode_word_piece_impl(const std::vector &text, return index == 0 || vkcom::is_spacing_char(text[index]) || vkcom::is_spacing_char(text[index - 1]); }; + const int unk_token_id = vocab.special_tokens.unk_id; - const auto worker = [&, unk_token_id = vocab.special_tokens.unk_id](size_t begin, size_t end) { + const auto worker = [&, unk_token_id](size_t begin, size_t end) { std::vector token_ids; token_ids.reserve((end - begin) / max_len + 1); @@ -143,7 +198,7 @@ std::vector encode_word_piece_impl(const std::vector &text, const uint32_t *segment_end = segment_begin + static_cast(word_len); const WordMap *word_to_id = is_word_prefix(begin) ? &prefix_to_id : &suffix_to_id; - vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); + vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); while (!segment.empty()) { auto it = word_to_id->find(segment.finish()); if (it != word_to_id->end()) { @@ -178,8 +233,9 @@ std::vector encode_word_piece_impl(const std::vector &text, return token_ids; }; - static constexpr size_t kWorkBatch = 1'000'000; + static const size_t kWorkBatch = 1000000; std::vector token_ids; + if (text.size() < 2 * kWorkBatch) { token_ids = worker(0, text.size()); } else { @@ -223,7 +279,7 @@ std::vector encode_word_piece(const char *text, size_t size, const WordPiec return {}; } vkcom::ThreadPool thread_pool(0); - const std::vector text_utf8 = utils::parseText(text, size, thread_pool); + const std::vector text_utf8 = parseText(text, size, thread_pool); return encode_word_piece_impl(text_utf8, vocab, thread_pool); } @@ -233,26 +289,20 @@ namespace vkcom::wordpiece { Status encode_as_ids(const std::string &text_path, const std::string& vocab_path, std::vector *ids) { - const uint64_t batch_limit = 10 * 1024 * 1024; + const uint64_t kBatchLimit = 10 * 1024 * 1024; + try { std::string text; Status status = fast_read_file_utf8(text_path, &text); if (!status.ok()) { return status; } - uint64_t processed = 0; - std::vector vocab = read_lines_from_stdin(batch_limit, &processed); - return encode_as_ids(text, vocab, ids); - } catch (const std::exception& ex) { - return Status(1, ex.what()); - } catch (...) { - return Status(1, "Unknown error"); - } -} - -Status encode_as_ids(const std::string &text, - const std::vector& vocab, std::vector *ids) { - try { + std::vector vocab; + { + std::ifstream fin(vocab_path); + uint64_t processed = 0; + vocab = read_lines(fin, kBatchLimit, &processed); + } WordPieceVocabulary word_piece_vocab(vocab); *ids = encode_word_piece(text.data(), text.size(), word_piece_vocab); return Status(); @@ -266,26 +316,20 @@ Status encode_as_ids(const std::string &text, Status encode_as_subwords(const std::string &text_path, const std::string& vocab_path, std::vector *subwords) { + const uint64_t kBatchLimit = 10 * 1024 * 1024; + try { std::string text; Status status = fast_read_file_utf8(text_path, &text); if (!status.ok()) { return status; } - uint64_t processed = 0; - std::vector vocab = read_lines_from_stdin(batch_limit, &processed); - return encode_as_subwords(text, vocab, subwords); - } catch (const std::exception& ex) { - return Status(1, ex.what()); - } catch (...) 
{ - return Status(1, "Unknown error"); - } -} - -Status encode_as_subwords(const std::string &text, - const std::vector& vocab, - std::vector *subwords) { - try { + std::vector vocab; + { + std::ifstream fin(vocab_path); + uint64_t processed = 0; + vocab = read_lines(fin, kBatchLimit, &processed); + } WordPieceVocabulary word_piece_vocab(vocab); std::vector ids = encode_word_piece(text.data(), text.size(), word_piece_vocab); for (int id : ids) { @@ -300,13 +344,21 @@ Status encode_as_subwords(const std::string &text, } Status decode(const std::vector& ids, - const std::vector& vocab, + const std::string& vocab_path, std::vector *subwords, const std::unordered_set *ignore_ids) { + const uint64_t kBatchLimit = 10 * 1024 * 1024; + try { + std::vector vocab; + { + std::ifstream fin(vocab_path); + uint64_t processed = 0; + vocab = read_lines(fin, kBatchLimit, &processed); + } for (int id : ids) { if (!ignore_ids || ignore_ids->count(id) == 0) { - subwords->push_back(vocab[id]); + subwords->push_back(vocab.at(id)); } } return Status(); @@ -317,4 +369,4 @@ Status decode(const std::vector& ids, } } -} // namespace vkcom::wordpiece \ No newline at end of file +} // namespace vkcom::wordpiece diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h index 92bc1b8..1a953d0 100644 --- a/youtokentome/cpp/wordpiece.h +++ b/youtokentome/cpp/wordpiece.h @@ -4,25 +4,20 @@ #include #include +#include "utils.h" + namespace vkcom::wordpiece { Status encode_as_ids(const std::string &text_path, const std::string& vocab_path, std::vector *ids); -Status encode_as_ids(const std::string &text, - const std::vector& vocab, std::vector *ids); - Status encode_as_subwords(const std::string &text_path, const std::string& vocab_path, std::vector *subwords); -Status encode_as_subwords(const std::string &text, - const std::vector& vocab, - std::vector *subwords); - Status decode(const std::vector& ids, - const std::vector& vocab, + const std::string& vocab_path, std::vector *subwords, const std::unordered_set *ignore_ids); -} // namespace vkcom::wordpiece \ No newline at end of file +} // namespace vkcom::wordpiece diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index a7964d0..652c836 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -177,41 +177,45 @@ cdef class BPE: cdef extern from "wordpiece.h" namespace "vkcom::wordpiece": Status encode_as_ids(const string &text_path, const string &vocab_path, vector[int] *ids) - Status encode_as_ids(const string &text, const vector[string] &vocab, vector[int] *ids) Status encode_as_subwords(const string &text_path, const string &vocab_path, vector[string] *subwords) - Status encode_as_subwords(const string &text, const vector[string] &vocab, vector[string] *subwords) - Status decode(const vector[int] &ids, const vector[string] &vocab, vector[string] *subwords, const unordered_set[int] *ignore_ids) + Status decode(const vector[int] &ids, const string &vocab_path, vector[string] *subwords, const unordered_set[int] *ignore_ids) cdef class WordPiece: - def __init__(self, vocab, n_threads=0): - self.vocab = vocab + cdef string vocab_path + cdef int n_threads + + def __dealloc__(self): + pass + + def __init__(self, vocab_path, n_threads=0): + self.vocab_path = vocab_path.encode() self.n_threads = n_threads - def encode(self, text, output_type): + def encode(self, text_path, output_type): cdef Status status + cdef vector[int] ids + cdef vector[string] subwords if output_type == 'id': - cdef vector[int] ids - status = 
encode_as_ids(text, self.vocab, &ids)
+            status = encode_as_ids(text_path.encode(), self.vocab_path, &ids)
             if status.code != 0:
                 raise ValueError(status.message.decode())
             return ids
         elif output_type == 'subword':
-            cdef vector[string] subwords
-            status = encode_as_subwords(text, self.vocab, &subwords)
+            status = encode_as_subwords(text_path.encode(), self.vocab_path, &subwords)
             if status.code != 0:
                 raise ValueError(status.message.decode())
             return subwords
         else:
             raise ValueError('output_type must be equal to "id" or "subword"')
 
-    def decode(self, ids, ignore_ids)
+    def decode(self, ids, ignore_ids):
         if ignore_ids is None:
             ignore_ids = set()
         cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids)
         cdef vector[string] subwords
-        cdef Status status = decode(ids, self.vocab, &subwords, &c_ignore_ids)
+        cdef Status status = decode(ids, self.vocab_path, &subwords, &c_ignore_ids)
         if status.code != 0:
             raise ValueError(status.message.decode())
         return subwords
diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py
index 593febf..7b3ded6 100644
--- a/youtokentome/youtokentome.py
+++ b/youtokentome/youtokentome.py
@@ -97,3 +97,40 @@ def __setstate__(self, dict):
         self.bpe_cython = _youtokentome_cython.BPE(
             model_path=self.model, n_threads=self.n_threads
         )
+
+
+class WordPiece:
+    def __init__(self, vocab_path: str, n_threads: int = 0):
+        self.vocab_path = vocab_path
+        self.n_threads = n_threads
+
+        self.word_piece_cython = _youtokentome_cython.WordPiece(
+            vocab_path=vocab_path, n_threads=n_threads
+        )
+
+    def encode(
+        self,
+        text_path: str,
+        output_type: OutputType = OutputType.ID
+    ) -> Union[List[List[int]], List[List[str]]]:
+        output_type_str = "id" if output_type == OutputType.ID else "subword"
+        return self.word_piece_cython.encode(text_path, output_type_str)
+
+    def decode(
+        self,
+        ids: List[int],
+        ignore_ids: Optional[Collection[int]] = None
+    ) -> List[str]:
+        return self.word_piece_cython.decode(ids, ignore_ids)
+
+    def __getstate__(self):
+        return {"vocab_path": self.vocab_path, "n_threads": self.n_threads}
+
+    def __setstate__(self, dict):
+        self.vocab_path = dict["vocab_path"]
+        self.n_threads = dict["n_threads"]
+
+        self.word_piece_cython = _youtokentome_cython.WordPiece(
+            vocab_path=self.vocab_path, n_threads=self.n_threads
+        )
+

From d7741ec1b85f5ad540f4e4aabe70d69ae4ef3431 Mon Sep 17 00:00:00 2001
From: Gleb Koveshnikov
Date: Sat, 8 Apr 2023 20:41:09 +0400
Subject: [PATCH 08/10] encoder class, stress tests

---
 tests/unit_tests/bpe/test_stress.py           |  10 +-
 tests/unit_tests/wordpiece/stress_test.cpp    | 176 ++++++++
 tests/unit_tests/wordpiece/test_stress.py     |  47 ++
 .../cpp/third_party/thread_pool/thread_pool.h |   8 +-
 youtokentome/cpp/wordpiece.cpp                | 407 ++++++++----------
 youtokentome/cpp/wordpiece.h                  |  67 ++-
 youtokentome/cpp/yttm.pyx                     |  26 +-
 7 files changed, 477 insertions(+), 264 deletions(-)
 create mode 100644 tests/unit_tests/wordpiece/stress_test.cpp
 create mode 100644 tests/unit_tests/wordpiece/test_stress.py
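Editor's note: the Python surface introduced in the previous patch can be exercised as below. This is a minimal sketch based only on the `WordPiece` wrapper shown above; the file names `vocab.txt` and `text.txt` are hypothetical, not part of the patch.

```python
import pickle

import youtokentome as yttm

# Hypothetical files: vocab.txt lists one token per line, text.txt is the corpus.
wp = yttm.WordPiece("vocab.txt", n_threads=4)

ids = wp.encode("text.txt", output_type=yttm.OutputType.ID)
subwords = wp.encode("text.txt", output_type=yttm.OutputType.SUBWORD)

# Ids listed in ignore_ids are skipped while decoding.
words = wp.decode(ids, ignore_ids={0, 1})

# Pickling stores only vocab_path and n_threads, so the vocab file
# must still exist wherever the object is unpickled.
wp_restored = pickle.loads(pickle.dumps(wp))
```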
diff --git a/tests/unit_tests/bpe/test_stress.py b/tests/unit_tests/bpe/test_stress.py
index 92ca8b7..03e94b0 100644
--- a/tests/unit_tests/bpe/test_stress.py
+++ b/tests/unit_tests/bpe/test_stress.py
@@ -12,13 +12,13 @@ def compile_test():
     files = ["../../../youtokentome/cpp/" + file_name for file_name in build_files]
     files.append("stress_test.cpp")
 
-    print("compiling stress test ...")
+    print("compiling bpe stress test ...")
 
     command = [
         "g++",
         *files,
         "-o",
-        "stress",
+        "bpe_stress",
         "-std=c++11",
         "-pthread",
         "-Og",
@@ -35,16 +35,16 @@ def compile_test():
 
 def test_stress():
     compile_test()
-    run(["./stress", "base", "1000"], check=True)
+    run(["./bpe_stress", "base", "1000"], check=True)
 
 
 def test_manual():
     compile_test()
-    run(["./stress", "manual"], check=True)
+    run(["./bpe_stress", "manual"], check=True)
     os.remove("remove_it.txt")
 
 
 def test_parallel():
     compile_test()
-    run(["./stress", "parallel", "50"], check=True)
+    run(["./bpe_stress", "parallel", "50"], check=True)
     os.remove("remove_it.txt")
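Editor's note: before the C++ stress harness below, it helps to pin down the behavior under test. The following is a rough pure-Python oracle of greedy longest-match-first WordPiece, matching what `tests/unit_tests/wordpiece/test_manual.py` asserts earlier in this series: continuation pieces carry a `##` prefix, a word with no match collapses to a single unknown id, and that id is -1 when the vocab has no `[UNK]`. It splits on whitespace only; the punctuation handling of the C++ tokenizer is deliberately left out.

```python
def wordpiece_encode(text, vocab, unk_id=-1):
    """Greedy longest-match-first WordPiece (simplified oracle, whitespace split only)."""
    word_to_id = {w: i for i, w in enumerate(vocab)}
    ids = []
    for word in text.split():
        pos, word_ids = 0, []
        while pos < len(word):
            prefix = "" if pos == 0 else "##"
            end = len(word)
            while end > pos and prefix + word[pos:end] not in word_to_id:
                end -= 1
            if end == pos:  # nothing matched: the whole word becomes a single unk
                word_ids = [unk_id]
                break
            word_ids.append(word_to_id[prefix + word[pos:end]])
            pos = end
        ids.extend(word_ids)
    return ids


# Mirrors an assertion from test_manual.py in this patch series.
assert wordpiece_encode(
    "токенизация это круто",
    ["ток", "крут", "это", "##за", "##ени", "##о", "##ция", "ция"],
) == [0, 4, 3, 6, 2, 1, 5]
```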
diff --git a/tests/unit_tests/wordpiece/stress_test.cpp b/tests/unit_tests/wordpiece/stress_test.cpp
new file mode 100644
index 0000000..6253a83
--- /dev/null
+++ b/tests/unit_tests/wordpiece/stress_test.cpp
@@ -0,0 +1,176 @@
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <random>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../../../youtokentome/cpp/wordpiece.h"
+
+namespace vkcom {
+
+struct TestCase {
+  std::string text;
+  std::vector<std::string> vocab;
+  std::vector<int> answer_encoded;
+  std::vector<std::string> answer_decoded;
+};
+
+template <typename T>
+void dump_vector(const std::string &filename, const std::vector<T> &vec, char delim) {
+  std::ofstream fout(filename);
+  for (const auto &item : vec) {
+    fout << item << delim;
+  }
+}
+
+void dump_test_case(const TestCase &test_case) {
+  {
+    std::ofstream fout("stress.txt");
+    fout << test_case.text;
+  }
+  dump_vector("vocab.txt", test_case.vocab, '\n');
+  dump_vector("answer_encoded.txt", test_case.answer_encoded, ' ');
+  dump_vector("answer_decoded.txt", test_case.answer_decoded, ' ');
+}
+
+void check(const TestCase &test_case,
+           const std::vector<int> &encoded,
+           const std::vector<std::string> &decoded) {
+  if (encoded != test_case.answer_encoded || decoded != test_case.answer_decoded) {
+    dump_test_case(test_case);
+    throw std::runtime_error("STRESS TEST FAILED, test case dumped");
+  }
+}
+
+std::string get_random_string(std::mt19937 &rnd, size_t string_length) {
+  static const std::string kAllChars = "abcdefghijklmnopqrstuvwxyz";
+  std::string result;
+  result.reserve(string_length);
+  while (string_length > 0) {
+    --string_length;
+    size_t index = std::uniform_int_distribution<size_t>(0ul, kAllChars.size() - 1)(rnd);
+    result.push_back(kAllChars[index]);
+  }
+  return result;
+}
+
+TestCase generate_test_case(size_t text_len, size_t parts) {
+  std::mt19937 rnd(17);
+  std::string text;
+  text.reserve(text_len + parts);
+  std::uniform_int_distribution<size_t> word_len(1ul, text_len / 2);
+
+  std::unordered_map<std::string, int> vocab_map;
+  std::vector<int> answer_encoded;
+  std::vector<std::string> answer_decoded;
+  answer_encoded.reserve(parts);
+
+  for (size_t i = 0; i < parts; i++) {
+    const size_t vocab_size = vocab_map.size();
+    if (i + 1 == parts) {
+      size_t leftover = text.capacity() > text.size() ? text.capacity() - text.size() : 0;
+      std::string word = get_random_string(rnd, std::max<size_t>(leftover, 1));
+      if (vocab_map.count(word) == 0) {
+        vocab_map[word] = static_cast<int>(vocab_size);
+      }
+      text.append(word);
+      answer_encoded.push_back(vocab_map[word]);
+      answer_decoded.push_back(std::move(word));
+    } else if (i % 10 == 0 && !vocab_map.empty()) {
+      std::uniform_int_distribution<size_t> rnd_word(0ul, vocab_size - 1);
+      auto it = std::next(vocab_map.begin(), static_cast<std::ptrdiff_t>(rnd_word(rnd)));
+      text.append(it->first);
+      text.push_back(' ');
+      answer_encoded.push_back(it->second);
+      answer_decoded.push_back(it->first);
+    } else {
+      std::string word = get_random_string(rnd, word_len(rnd));
+      if (vocab_map.count(word) == 0) {
+        vocab_map[word] = static_cast<int>(vocab_size);
+      }
+      text.append(word);
+      text.push_back(' ');
+      answer_encoded.push_back(vocab_map[word]);
+      answer_decoded.push_back(std::move(word));
+    }
+  }
+
+  std::vector<std::string> vocab;
+  vocab.resize(vocab_map.size());
+  for (auto it = vocab_map.begin(); it != vocab_map.end(); ++it) {
+    vocab[static_cast<size_t>(it->second)] = it->first;
+  }
+  return TestCase{std::move(text), std::move(vocab), std::move(answer_encoded), std::move(answer_decoded)};
+}
+
+void test_stress(size_t text_len_from,
+                 size_t text_len_to,
+                 size_t text_len_step,
+                 size_t parts_from,
+                 size_t parts_to,
+                 int n_threads) {
+  for (size_t text_len = text_len_from; text_len <= text_len_to; text_len += text_len_step) {
+    for (size_t parts = std::min(text_len, parts_from); parts <= std::min(text_len, parts_to);
+         parts++) {
+      TestCase test_case = generate_test_case(text_len, parts);
+      std::cout << "running stress, text_len " << test_case.text.size() << ", vocab_size "
+                << test_case.vocab.size() << std::endl;
+
+      // Encoder reads the corpus from a file, so dump the generated text first.
+      {
+        std::ofstream fout("stress.txt");
+        fout << test_case.text;
+      }
+
+      Status status;
+      std::vector<int> encoded;
+      wordpiece::Encoder encoder(test_case.vocab, n_threads);
+      status = encoder.encode_as_ids("stress.txt", &encoded);
+      if (!status.ok()) {
+        dump_test_case(test_case);
+        throw std::runtime_error("encode_as_ids failed, test case dumped");
+      }
+      std::vector<std::string> decoded;
+      status = encoder.encode_as_subwords("stress.txt", &decoded);
+      if (!status.ok()) {
+        dump_test_case(test_case);
+        throw std::runtime_error("encode_as_subwords failed, test case dumped");
+      }
+
+      check(test_case, encoded, decoded);
+    }
+  }
+}
+
+// Assumed cap for the generated vocabulary in the large runs (roughly BERT-sized).
+const size_t kWordPieceVocabSize = 30000;
+
+void run_small(int n_threads) {
+  test_stress(10, 300, 5, 2, 100, n_threads);
+}
+
+void run_large(int n_threads) {
+  test_stress(100000, 1000000, 400000, kWordPieceVocabSize, kWordPieceVocabSize, n_threads);
+  test_stress(10000000, 10000000, 200000, kWordPieceVocabSize, kWordPieceVocabSize, n_threads);
+}
+
+} // namespace vkcom
+
+int main(int argc, char **argv) {
+  if (argc == 2 && std::strcmp(argv[1], "small") == 0) {
+    vkcom::run_small(1);
+  } else if (argc == 2 && std::strcmp(argv[1], "large") == 0) {
+    vkcom::run_large(1);
+  } else if (argc == 2 && std::strcmp(argv[1], "parallel") == 0) {
+    vkcom::run_small(0);
+    vkcom::run_large(0);
+  } else {
+    assert(false);
+  }
+}
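Editor's note: the generator above builds the expected ids by construction. Every emitted word is inserted into the vocabulary whole, so greedy longest-match must return exactly one id per word. A loose Python analogue, handy for producing corpora to feed the compiled binary by hand (the helper name and parameters are illustrative, not from the patch):

```python
import random


def make_case(n_words, seed=17):
    """Corpus + vocab + expected ids; every word is a whole vocab entry."""
    rnd = random.Random(seed)
    vocab, index, words, ids = [], {}, [], []
    for i in range(n_words):
        if vocab and i % 10 == 0:
            word = rnd.choice(vocab)  # occasionally repeat a known word
        else:
            word = "".join(
                rnd.choice("abcdefghijklmnopqrstuvwxyz")
                for _ in range(rnd.randint(1, 12))
            )
        if word not in index:
            index[word] = len(vocab)
            vocab.append(word)
        words.append(word)
        ids.append(index[word])
    return " ".join(words), vocab, ids
```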
diff --git a/tests/unit_tests/wordpiece/test_stress.py b/tests/unit_tests/wordpiece/test_stress.py
new file mode 100644
index 0000000..c28e2be
--- /dev/null
+++ b/tests/unit_tests/wordpiece/test_stress.py
@@ -0,0 +1,47 @@
+from subprocess import run
+
+
+tests_compiled = False
+
+
+def compile_test():
+    global tests_compiled
+    if tests_compiled:
+        return
+    build_files = ["wordpiece.cpp", "utils.cpp", "utf8.cpp"]
+    files = ["../../../youtokentome/cpp/" + file_name for file_name in build_files]
+    files.append("stress_test.cpp")
+
+    print("compiling wordpiece stress test ...")
+
+    command = [
+        "g++",
+        *files,
+        "-o",
+        "wordpiece_stress",
+        "-std=c++17",
+        "-pthread",
+        "-Og",
+        "-D_GLIBCXX_DEBUG",
+        "-fno-omit-frame-pointer -fsanitize=address -fsanitize=leak -fsanitize=undefined",
+    ]
+
+    command = " ".join(command)
+    print("command:", command)
+    run(command, check=True, shell=True)
+    tests_compiled = True
+
+
+def test_small():
+    compile_test()
+    run(["./wordpiece_stress", "small"], check=True)
+
+
+def test_large():
+    compile_test()
+    run(["./wordpiece_stress", "large"], check=True)
+
+
+def test_parallel():
+    compile_test()
+    run(["./wordpiece_stress", "parallel"], check=True)
diff --git a/youtokentome/cpp/third_party/thread_pool/thread_pool.h b/youtokentome/cpp/third_party/thread_pool/thread_pool.h
index 96884f6..459cbc9 100644
--- a/youtokentome/cpp/third_party/thread_pool/thread_pool.h
+++ b/youtokentome/cpp/third_party/thread_pool/thread_pool.h
@@ -18,14 +18,14 @@ class ThreadPool {
   using Task = std::function<void()>;
 
  public:
-  ThreadPool(size_t thread_count) {
-    if (thread_count == 0) {
-      thread_count = static_cast<size_t>(std::thread::hardware_concurrency());
+  ThreadPool(int thread_count) {
+    if (thread_count <= 0) {
+      thread_count = static_cast<int>(std::thread::hardware_concurrency());
     }
     if (thread_count == 0) {
       thread_count = 8;
     }
-    for (size_t thread = 0; thread < thread_count; ++thread) {
+    for (int thread = 0; thread < thread_count; ++thread) {
       threads_.emplace_back([this] {
         while (!stop_.load(std::memory_order_relaxed)) {
           std::unique_lock<std::mutex> lock(mutex_);
diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp
index 8551d5f..3682dc2 100644
--- a/youtokentome/cpp/wordpiece.cpp
+++ b/youtokentome/cpp/wordpiece.cpp
@@ -4,12 +4,7 @@
 #include
 #include
 #include
-#include
-#include
-#include
 
-#include "third_party/flat_hash_map/flat_hash_map.h"
-#include "third_party/thread_pool/thread_pool.h"
 #include "utf8.h"
 
 namespace {
@@ -19,24 +14,72 @@ const std::string PAD_TOKEN = "[PAD]";
 const std::string BOS_TOKEN = "[BOS]";
 const std::string EOS_TOKEN = "[EOS]";
 
-bool isSuffixVocab(const std::vector<uint32_t> &word) {
+bool is_suffix_vocab(const std::vector<uint32_t> &word) {
   static const uint32_t kSharp = static_cast<uint32_t>('#');
   return word.size() >= 2 && word[0] == kSharp && word[1] == kSharp;
 }
 
-bool isSpecialToken(const std::vector<uint32_t> &word) {
+bool is_special_token(const std::vector<uint32_t> &word) {
   return word.size() > 2 && word[0] == static_cast<uint32_t>('[')
          && word.back() == static_cast<uint32_t>(']');
 }
 
+std::vector<uint32_t> parse_text(const char *text, size_t size, vkcom::ThreadPool &thread_pool) {
+  static const size_t kWorkBatch = 5000000;
+
+  if (size < 2 * kWorkBatch) {
+    return vkcom::decode_utf8(text, text + size);
+  }
+
+  const size_t thread_count = std::min(thread_pool.maxThreads(), size / kWorkBatch);
+  const size_t work_batch = size / thread_count + 1;
+  std::vector<std::vector<uint32_t>> per_thread_text_utf8(thread_count);
+  size_t work_start = 0;
+  for (size_t thread_id = 0; thread_id < thread_count && work_start < size; thread_id++) {
+    size_t work_end = std::min(size, work_start + work_batch);
+    while (work_end < size && !vkcom::check_symbol_start(text[work_end])) {
+      ++work_end;
+    }
+    thread_pool.submit([thread_id, work_start, work_end, text, &per_thread_text_utf8] {
+      const char *begin = text + work_start;
+      const size_t len = work_end - work_start;
+      per_thread_text_utf8[thread_id] = vkcom::decode_utf8(begin, begin + len);
+    });
+    work_start = work_end;
+  }
+
+  thread_pool.waitCompletion();
+  size_t text_utf8_size = 0;
+  for (size_t thread_id = 
0; thread_id < thread_count; thread_id++) { + text_utf8_size += per_thread_text_utf8[thread_id].size(); + } + std::vector text_utf8(text_utf8_size); + text_utf8.resize(text_utf8_size); + work_start = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector &segment = per_thread_text_utf8[thread_id]; + if (!segment.empty()) { + std::memcpy(text_utf8.data() + work_start, + segment.data(), + segment.size() * sizeof(uint32_t)); + work_start += segment.size(); + } + } + + return text_utf8; +} + +} // namespace + +namespace vkcom::wordpiece { + +WordPieceToken::WordPieceToken(const std::string &encoded_word) : is_prefix(true), is_special(false), is_malformed(false), word(vkcom::decode_utf8(encoded_word)) { - if (isSuffixVocab(word)) { + if (is_suffix_vocab(word)) { is_prefix = false; word.erase(word.begin(), word.begin() + 2); - } else if (isSpecialToken(word)) { + } else if (is_special_token(word)) { is_special = true; } @@ -56,43 +99,22 @@ struct WordPieceToken { is_malformed = true; std::cerr << "Vocab word is malformed: " << encoded_word << std::endl; } - } - - bool is_prefix; - bool is_special; - bool is_malformed; - std::vector word; -}; +} -struct WordPieceVocabulary { - explicit WordPieceVocabulary(const std::vector& words) { +WordPieceVocabulary::WordPieceVocabulary(const std::vector& words) { tokens.reserve(words.size()); int token_id = 0; + max_token_len = 0; for (const std::string& word : words) { update_special_tokens(word, token_id); WordPieceToken token(word); + max_token_len = std::max(max_token_len, token.word.size()); tokens.push_back(std::move(token)); ++token_id; } - } - - explicit WordPieceVocabulary(const std::string &file) { - std::ifstream fin(file); - std::string word; - int token_id = 0; - while (std::getline(fin, word)) { - update_special_tokens(word, token_id); - WordPieceToken token(word); - tokens.push_back(std::move(token)); - ++token_id; - } - } - - std::vector tokens; - vkcom::SpecialTokens special_tokens; +} -private: - void update_special_tokens(const std::string& word, int token_id) { +void WordPieceVocabulary::update_special_tokens(const std::string& word, int token_id) { if (word == UNK_TOKEN) { special_tokens.unk_id = token_id; } else if (word == PAD_TOKEN) { @@ -102,91 +124,142 @@ struct WordPieceVocabulary { } else if (word == EOS_TOKEN) { special_tokens.eos_id = token_id; } - } -}; +} -std::vector parseText(const char *text, size_t size, vkcom::ThreadPool &thread_pool) { - static const size_t kWorkBatch = 5000000; +Encoder::Encoder(const std::string &vocab_path, int n_threads) : Encoder(read_lines(std::ifstream(vocab_path)), n_threads) {} - if (size < 2 * kWorkBatch) { - return vkcom::decode_utf8(text, text + size); - } else { - const size_t thread_count = std::min(thread_pool.maxThreads(), size / kWorkBatch); - const size_t work_batch = size / thread_count + 1; - std::vector> per_thread_text_utf8(thread_count); - size_t work_start = 0; - for (size_t thread_id = 0; thread_id < thread_count && work_start < size; thread_id++) { - size_t work_end = std::min(size, work_start + work_batch); - while (work_end < size && !vkcom::check_symbol_start(text[work_end])) { - ++work_end; - } - thread_pool.submit([thread_id, work_start, work_end, text, &per_thread_text_utf8] { - const char *begin = text + work_start; - const size_t len = work_end - work_start; - per_thread_text_utf8[thread_id] = vkcom::decode_utf8(begin, begin + len); - }); - work_start = work_end; +Encoder::Encoder(std::vector vocab, int n_threads) : 
vocab_(std::move(vocab)), word_piece_vocab_(vocab_), thread_pool_(n_threads) { + build_word_maps(); +} + +Status Encoder::encode_as_ids(const std::string &text_path, std::vector *ids) const { + try { + std::string text_str; + Status status = fast_read_file_utf8(text_path, &text_str); + if (!status.ok()) { + return status; } + const std::vector text = parse_text(text_str, text_str.size(), thread_pool_); + *ids = encode_parallel(text); + return Status(); + } catch (const std::exception &ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); + } +} - thread_pool.waitCompletion(); - size_t text_utf8_size = 0; - for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { - text_utf8_size += per_thread_text_utf8[thread_id].size(); +Status Encoder::encode_as_subwords(const std::string &text_path, + std::vector *subwords) const { + try { + std::string text_str; + Status status = fast_read_file_utf8(text_path, &text_str); + if (!status.ok()) { + return status; } - std::vector text_utf8(text_utf8_size); - text_utf8.resize(text_utf8_size); - work_start = 0; - for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { - std::vector &segment = per_thread_text_utf8[thread_id]; - if (!segment.empty()) { - std::memcpy(text_utf8.data() + work_start, - segment.data(), - segment.size() * sizeof(uint32_t)); - work_start += segment.size(); + const std::vector text = parse_text(text_str, text_str.size(), thread_pool_); + std::vector ids = encode_parallel(text); + for (int id : ids) { + subwords->push_back(vocab[id]); + } + return Status(); + } catch (const std::exception &ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); + } +} + +Status Encoder::decode(const std::vector &ids, + std::vector *subwords, + const std::unordered_set *ignore_ids) const { + try { + for (int id : ids) { + if (!ignore_ids || ignore_ids->count(id) == 0) { + subwords->push_back(vocab_.at(id)); } } - return text_utf8; + return Status(); + } catch (const std::exception &ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); } } -std::vector encode_word_piece_impl(const std::vector &text, - const WordPieceVocabulary &vocab, - vkcom::ThreadPool& thread_pool) { - using WordMap = std::unordered_map; - WordMap prefix_to_id; // no ## in word prefix - WordMap suffix_to_id; // ## in word prefix +bool Encoder::is_word_prefix(const std::vector &text, size_t index) { + return index == 0 || vkcom::is_spacing_char(text[index]) + || vkcom::is_spacing_char(text[index - 1]); +} - size_t max_len = 0; +void Encoder::build_word_maps() { for (size_t i = 0; i < vocab.tokens.size(); i++) { const auto &token = vocab.tokens[i]; if (token.is_special || token.is_malformed) { continue; } - max_len = std::max(max_len, token.word.size()); vkcom::VectorSegmentBuilder segment(token.word); - WordMap *word_to_id = token.is_prefix ? &prefix_to_id : &suffix_to_id; + WordMap *word_to_id = token.is_prefix ? 
&prefix_to_id_ : &suffix_to_id_; (*word_to_id)[segment.finish()] = static_cast(i); } - max_len = std::min(max_len, text.size()); +} - const auto is_word_prefix = [&text](size_t index) { - return index == 0 || vkcom::is_spacing_char(text[index]) - || vkcom::is_spacing_char(text[index - 1]); - }; - const int unk_token_id = vocab.special_tokens.unk_id; +std::vector Encoder::encode_parallel(const std::vector &text) const { + static const size_t kWorkBatch = 1000000; - const auto worker = [&, unk_token_id](size_t begin, size_t end) { - std::vector token_ids; - token_ids.reserve((end - begin) / max_len + 1); + if (text.size() < 2 * kWorkBatch) { + return encode_impl(text, 0, text.size()); + } - while (begin != end && vkcom::is_space(text[begin])) { - ++begin; + const size_t thread_count = std::min(thread_pool_.maxThreads(), text.size() / kWorkBatch); + const size_t work_batch = text.size() / thread_count + 1; + std::vector> per_thread_token_ids(thread_count); + size_t work_begin = 0; + for (size_t thread_id = 0; thread_id < thread_count && work_begin < text.size(); thread_id++) { + size_t work_end = std::min(text.size(), work_begin + work_batch); + while (work_end < text.size() && !vkcom::is_space(text[work_end])) { + ++work_end; } + thread_pool_.submit([thread_id, work_begin, work_end, &per_thread_token_ids, &worker, &text] { + per_thread_token_ids[thread_id] = encode_impl(text, work_begin, work_end); + }); + work_begin = work_end; + } + + thread_pool_.waitCompletion(); + + size_t token_count = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + token_count += per_thread_token_ids[thread_id].size(); + } + std::vector token_ids(token_count); + work_begin = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector &segment = per_thread_token_ids[thread_id]; + if (!segment.empty()) { + std::memcpy(token_ids.data() + work_begin, segment.data(), segment.size() * sizeof(int)); + work_begin += segment.size(); + } + } + + return token_ids; +} + +std::vector Encoder::encode_impl(const std::vector &text, size_t begin, size_t end) const { + size_t max_len = std::min(word_piece_vocab_.max_token_len, end - begin); + const int unk_token_id = vocab.special_tokens.unk_id; - size_t tokens_since_prefix = 0; + std::vector token_ids; + token_ids.reserve((end - begin) / max_len + 1); + + while (begin != end && vkcom::is_space(text[begin])) { + ++begin; + } - while (begin != end) { - size_t word_len = 1; + size_t tokens_since_prefix = 0; + + while (begin != end) { + size_t word_len = 1; if (!vkcom::is_punctuation(text[begin])) { while (word_len < std::min(max_len, end - begin) && !vkcom::is_spacing_char(text[begin + word_len])) { @@ -196,7 +269,7 @@ std::vector encode_word_piece_impl(const std::vector &text, const uint32_t *segment_begin = text.data() + static_cast(begin); const uint32_t *segment_end = segment_begin + static_cast(word_len); - const WordMap *word_to_id = is_word_prefix(begin) ? &prefix_to_id : &suffix_to_id; + const WordMap *word_to_id = is_word_prefix(begin) ? 
&prefix_to_id_ : &suffix_to_id_; vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); while (!segment.empty()) { @@ -228,145 +301,9 @@ std::vector encode_word_piece_impl(const std::vector &text, while (begin != end && vkcom::is_space(text[begin])) { ++begin; } - } - - return token_ids; - }; - - static const size_t kWorkBatch = 1000000; - std::vector token_ids; - - if (text.size() < 2 * kWorkBatch) { - token_ids = worker(0, text.size()); - } else { - const size_t thread_count = std::min(thread_pool.maxThreads(), text.size() / kWorkBatch); - const size_t work_batch = text.size() / thread_count + 1; - std::vector> per_thread_token_ids(thread_count); - size_t work_begin = 0; - for (size_t thread_id = 0; thread_id < thread_count && work_begin < text.size(); thread_id++) { - size_t work_end = std::min(text.size(), work_begin + work_batch); - while (work_end < text.size() && !vkcom::is_space(text[work_end])) { - ++work_end; - } - thread_pool.submit([thread_id, work_begin, work_end, &per_thread_token_ids, &worker] { - per_thread_token_ids[thread_id] = worker(work_begin, work_end); - }); - work_begin = work_end; - } - - thread_pool.waitCompletion(); - - size_t token_count = 0; - for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { - token_count += per_thread_token_ids[thread_id].size(); - } - token_ids.resize(token_count); - work_begin = 0; - for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { - std::vector &segment = per_thread_token_ids[thread_id]; - if (!segment.empty()) { - std::memcpy(token_ids.data() + work_begin, segment.data(), segment.size() * sizeof(int)); - work_begin += segment.size(); - } - } } return token_ids; } -std::vector encode_word_piece(const char *text, size_t size, const WordPieceVocabulary &vocab) { - if (size == 0) { - return {}; - } - vkcom::ThreadPool thread_pool(0); - const std::vector text_utf8 = parseText(text, size, thread_pool); - return encode_word_piece_impl(text_utf8, vocab, thread_pool); -} - -} // namespace - -namespace vkcom::wordpiece { - -Status encode_as_ids(const std::string &text_path, - const std::string& vocab_path, std::vector *ids) { - const uint64_t kBatchLimit = 10 * 1024 * 1024; - - try { - std::string text; - Status status = fast_read_file_utf8(text_path, &text); - if (!status.ok()) { - return status; - } - std::vector vocab; - { - std::ifstream fin(vocab_path); - uint64_t processed = 0; - vocab = read_lines(fin, kBatchLimit, &processed); - } - WordPieceVocabulary word_piece_vocab(vocab); - *ids = encode_word_piece(text.data(), text.size(), word_piece_vocab); - return Status(); - } catch (const std::exception& ex) { - return Status(1, ex.what()); - } catch (...) { - return Status(1, "Unknown error"); - } -} - -Status encode_as_subwords(const std::string &text_path, - const std::string& vocab_path, - std::vector *subwords) { - const uint64_t kBatchLimit = 10 * 1024 * 1024; - - try { - std::string text; - Status status = fast_read_file_utf8(text_path, &text); - if (!status.ok()) { - return status; - } - std::vector vocab; - { - std::ifstream fin(vocab_path); - uint64_t processed = 0; - vocab = read_lines(fin, kBatchLimit, &processed); - } - WordPieceVocabulary word_piece_vocab(vocab); - std::vector ids = encode_word_piece(text.data(), text.size(), word_piece_vocab); - for (int id : ids) { - subwords->push_back(vocab[id]); - } - return Status(); - } catch (const std::exception& ex) { - return Status(1, ex.what()); - } catch (...) 
{
-    return Status(1, "Unknown error");
-  }
-}
-
-Status decode(const std::vector<int>& ids,
-              const std::string& vocab_path,
-              std::vector<std::string> *subwords,
-              const std::unordered_set<int> *ignore_ids) {
-  const uint64_t kBatchLimit = 10 * 1024 * 1024;
-
-  try {
-    std::vector<std::string> vocab;
-    {
-      std::ifstream fin(vocab_path);
-      uint64_t processed = 0;
-      vocab = read_lines(fin, kBatchLimit, &processed);
-    }
-    for (int id : ids) {
-      if (!ignore_ids || ignore_ids->count(id) == 0) {
-        subwords->push_back(vocab.at(id));
-      }
-    }
-    return Status();
-  } catch (const std::exception& ex) {
-    return Status(1, ex.what());
-  } catch (...) {
-    return Status(1, "Unknown error");
-  }
-}
-
-} // namespace vkcom::wordpiece
\ No newline at end of file
+} // namespace vkcom::wordpiece
diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h
index 1a953d0..be83911 100644
--- a/youtokentome/cpp/wordpiece.h
+++ b/youtokentome/cpp/wordpiece.h
@@ -4,20 +4,69 @@
 #include
 #include
 
+#include "third_party/thread_pool/thread_pool.h"
 #include "utils.h"
 
 namespace vkcom::wordpiece {
 
-Status encode_as_ids(const std::string &text_path,
-                     const std::string& vocab_path, std::vector<int> *ids);
+struct WordPieceToken {
+  explicit WordPieceToken(const std::string &encoded_word);
 
-Status encode_as_subwords(const std::string &text_path,
-                          const std::string& vocab_path,
-                          std::vector<std::string> *subwords);
+  bool is_prefix;
+  bool is_special;
+  bool is_malformed;
+  std::vector<uint32_t> word;
+};
 
-Status decode(const std::vector<int>& ids,
-              const std::string& vocab_path,
-              std::vector<std::string> *subwords,
-              const std::unordered_set<int> *ignore_ids);
+struct WordPieceVocabulary {
+  explicit WordPieceVocabulary(const std::vector<std::string>& words);
+
+  std::vector<WordPieceToken> tokens;
+  vkcom::SpecialTokens special_tokens;
+  size_t max_token_len = 0;
+
+private:
+  void update_special_tokens(const std::string& word, int token_id);
+};
+
+class Encoder {
+public:
+  explicit Encoder(const std::string &vocab_path, int n_threads);
+
+  explicit Encoder(std::vector<std::string> vocab, int n_threads);
+
+  Status encode_as_ids(const std::string &text_path, std::vector<int> *ids) const;
+
+  Status encode_as_subwords(const std::string &text_path,
+                            std::vector<std::string> *subwords) const;
+
+  Status decode(const std::vector<int> &ids,
+                std::vector<std::string> *subwords,
+                const std::unordered_set<int> *ignore_ids) const;
+
+  Status id_to_subword(int id, std::string *subword) const;
+
+  int subword_to_id(const std::string &token) const;
+
+private:
+  static const uint64_t kReadBatchLimit = 10 * 1024 * 1024;
+
+  static bool is_word_prefix(const std::vector<uint32_t> &text, size_t index);
+
+  void build_word_maps();
+
+  std::vector<int> encode_parallel(const std::vector<uint32_t> &text) const;
+  std::vector<int> encode_impl(const std::vector<uint32_t> &text, size_t begin, size_t end) const;
+
+  std::vector<std::string> vocab_;
+  WordPieceVocabulary word_piece_vocab_;
+
+  // TODO: flat_hash_map ?
+  using WordMap = std::unordered_map<vkcom::WordPieceVectorSegment, int>;
+  WordMap prefix_to_id_; // no ## in word prefix
+  WordMap suffix_to_id_; // ## in word prefix
+
+  mutable ThreadPool thread_pool_;
+};
 
 } // namespace vkcom::wordpiece
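Editor's note on the WordMap key type above: `prefix_to_id_` and `suffix_to_id_` are keyed by vector segments that compare through a precomputed polynomial hash (see `VectorSegmentBuilder` in utils.h, with `MOD = 2032191299` and `P = 726328703`). Because the builder keeps every prefix hash, shrinking a candidate segment by one code point during the longest-match search re-hashes in O(1). The arithmetic, in Python terms:

```python
MOD = 2032191299
P = 726328703


def prefix_hashes(code_points):
    """h[i] = (h[i-1] * P + x[i]) % MOD, as VectorSegmentBuilder precomputes.

    Dropping the last element of a segment just means reading h[i-1],
    so no re-scan is needed while the tokenizer shortens its candidate.
    """
    hashes, h = [], 0
    for x in code_points:
        h = (h * P + x) % MOD
        hashes.append(h)
    return hashes
```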
+ using WordMap = std::unordered_map; + WordMap prefix_to_id_; // no ## in word prefix + WordMap suffix_to_id_; // ## in word prefix + + mutable ThreadPool thread_pool_; +}; } // namespace vkcom::wordpiece diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index 652c836..e149448 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -176,34 +176,38 @@ cdef class BPE: self.encoder.vocab_cli(verbose) cdef extern from "wordpiece.h" namespace "vkcom::wordpiece": - Status encode_as_ids(const string &text_path, const string &vocab_path, vector[int] *ids) + cdef cppclass Encoder: + Encoder(const string &vocab_path, int n_threads) - Status encode_as_subwords(const string &text_path, const string &vocab_path, vector[string] *subwords) + Status encode_as_ids(const string &text_path, vector[int] *ids) const - Status decode(const vector[int] &ids, const string &vocab_path, vector[string] *subwords, const unordered_set[int] *ignore_ids) + Status encode_as_subwords(const string &text_path, vector[string] *subwords) const + + Status decode(const vector[int] &ids, const string &vocab_path, vector[string] *subwords, const unordered_set[int] *ignore_ids) const + + Status id_to_subword(int id, string *subword) const + int subword_to_id(const string &subword) const cdef class WordPiece: - cdef string vocab_path - cdef int n_threads + cdef Encoder *encoder def __dealloc__(self): - pass + del self.encoder def __init__(self, vocab_path, n_threads=0): - self.vocab_path = vocab_path.encode() - self.n_threads = n_threads + self.encoder = new Encoder(vocab_path.encode(), n_threads) def encode(self, text_path, output_type): cdef Status status cdef vector[int] ids cdef vector[string] subwords if output_type == 'id': - status = encode_as_ids(text_path.encode(), self.vocab_path, &ids) + status = self.encoder.encode_as_ids(text_path.encode(), &ids) if status.code != 0: raise ValueError(status.message.decode()) return ids elif output_type == 'subword': - status = encode_as_subwords(text_path.encode(), self.vocab_path, &subwords) + status = self.encoder.encode_as_subwords(text_path.encode(), &subwords) if status.code != 0: raise ValueError(status.message.decode()) return subwords @@ -215,7 +219,7 @@ cdef class WordPiece: ignore_ids = set() cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids) cdef vector[string] subwords - cdef Status status = decode(ids, self.vocab_path, &subwords, &c_ignore_ids) + cdef Status status = self.encoder.decode(ids, &subwords, &c_ignore_ids) if status.code != 0: raise ValueError(status.message.decode()) return subwords From 4228b285831dc2d720d448083983d16e1bbd0417 Mon Sep 17 00:00:00 2001 From: Gleb Koveshnikov Date: Sat, 8 Apr 2023 20:42:28 +0400 Subject: [PATCH 09/10] format --- youtokentome/cpp/wordpiece.cpp | 210 ++++++++++++++++----------------- youtokentome/cpp/wordpiece.h | 15 ++- 2 files changed, 112 insertions(+), 113 deletions(-) diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp index 3682dc2..e8f340d 100644 --- a/youtokentome/cpp/wordpiece.cpp +++ b/youtokentome/cpp/wordpiece.cpp @@ -36,34 +36,32 @@ std::vector parse_text(const char *text, size_t size, vkcom::ThreadPoo std::vector> per_thread_text_utf8(thread_count); size_t work_start = 0; for (size_t thread_id = 0; thread_id < thread_count && work_start < size; thread_id++) { - size_t work_end = std::min(size, work_start + work_batch); - while (work_end < size && !vkcom::check_symbol_start(text[work_end])) { - ++work_end; - } - 
thread_pool.submit([thread_id, work_start, work_end, text, &per_thread_text_utf8] { - const char *begin = text + work_start; - const size_t len = work_end - work_start; - per_thread_text_utf8[thread_id] = vkcom::decode_utf8(begin, begin + len); - }); - work_start = work_end; + size_t work_end = std::min(size, work_start + work_batch); + while (work_end < size && !vkcom::check_symbol_start(text[work_end])) { + ++work_end; } + thread_pool.submit([thread_id, work_start, work_end, text, &per_thread_text_utf8] { + const char *begin = text + work_start; + const size_t len = work_end - work_start; + per_thread_text_utf8[thread_id] = vkcom::decode_utf8(begin, begin + len); + }); + work_start = work_end; + } - thread_pool.waitCompletion(); - size_t text_utf8_size = 0; - for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { - text_utf8_size += per_thread_text_utf8[thread_id].size(); + thread_pool.waitCompletion(); + size_t text_utf8_size = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + text_utf8_size += per_thread_text_utf8[thread_id].size(); + } + std::vector text_utf8(text_utf8_size); + text_utf8.resize(text_utf8_size); + work_start = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector &segment = per_thread_text_utf8[thread_id]; + if (!segment.empty()) { + std::memcpy(text_utf8.data() + work_start, segment.data(), segment.size() * sizeof(uint32_t)); + work_start += segment.size(); } - std::vector text_utf8(text_utf8_size); - text_utf8.resize(text_utf8_size); - work_start = 0; - for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { - std::vector &segment = per_thread_text_utf8[thread_id]; - if (!segment.empty()) { - std::memcpy(text_utf8.data() + work_start, - segment.data(), - segment.size() * sizeof(uint32_t)); - work_start += segment.size(); - } } return text_utf8; @@ -74,61 +72,62 @@ std::vector parse_text(const char *text, size_t size, vkcom::ThreadPoo namespace vkcom::wordpiece { WordPieceToken::WordPieceToken(const std::string &encoded_word) - : is_prefix(true), is_special(false), is_malformed(false), - word(vkcom::decode_utf8(encoded_word)) { + : is_prefix(true), is_special(false), is_malformed(false), word(vkcom::decode_utf8(encoded_word)) { if (is_suffix_vocab(word)) { - is_prefix = false; - word.erase(word.begin(), word.begin() + 2); - } else if (is_special_token(word)) { - is_special = true; - } + is_prefix = false; + word.erase(word.begin(), word.begin() + 2); + } else if (is_special_token(word)) { + is_special = true; + } - bool all_punctuation = true; - for (uint32_t code_point : word) { - if (code_point == vkcom::INVALID_UNICODE) { - is_malformed = true; - } - if (!vkcom::is_punctuation(code_point) && !vkcom::is_space(code_point)) { - all_punctuation = false; - } - } - if (word.empty()) { - throw std::runtime_error("Vocab word is empty"); - } - if (is_malformed || (all_punctuation && word.size() > 1)) { + bool all_punctuation = true; + for (uint32_t code_point : word) { + if (code_point == vkcom::INVALID_UNICODE) { is_malformed = true; - std::cerr << "Vocab word is malformed: " << encoded_word << std::endl; } + if (!vkcom::is_punctuation(code_point) && !vkcom::is_space(code_point)) { + all_punctuation = false; + } + } + if (word.empty()) { + throw std::runtime_error("Vocab word is empty"); + } + if (is_malformed || (all_punctuation && word.size() > 1)) { + is_malformed = true; + std::cerr << "Vocab word is malformed: " << encoded_word << std::endl; + } } 
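+
+// Each vocab entry is classified exactly once, at load time; build_word_maps
+// later skips entries flagged is_special or is_malformed, so they can never
+// be produced by a dictionary lookup during encoding.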
-WordPieceVocabulary::WordPieceVocabulary(const std::vector& words) { - tokens.reserve(words.size()); - int token_id = 0; - max_token_len = 0; - for (const std::string& word : words) { - update_special_tokens(word, token_id); - WordPieceToken token(word); - max_token_len = std::max(max_token_len, token.word.size()); - tokens.push_back(std::move(token)); - ++token_id; - } +WordPieceVocabulary::WordPieceVocabulary(const std::vector &words) { + tokens.reserve(words.size()); + int token_id = 0; + max_token_len = 0; + for (const std::string &word : words) { + update_special_tokens(word, token_id); + WordPieceToken token(word); + max_token_len = std::max(max_token_len, token.word.size()); + tokens.push_back(std::move(token)); + ++token_id; + } } -void WordPieceVocabulary::update_special_tokens(const std::string& word, int token_id) { - if (word == UNK_TOKEN) { - special_tokens.unk_id = token_id; - } else if (word == PAD_TOKEN) { - special_tokens.pad_id = token_id; - } else if (word == BOS_TOKEN) { - special_tokens.bos_id = token_id; - } else if (word == EOS_TOKEN) { - special_tokens.eos_id = token_id; - } +void WordPieceVocabulary::update_special_tokens(const std::string &word, int token_id) { + if (word == UNK_TOKEN) { + special_tokens.unk_id = token_id; + } else if (word == PAD_TOKEN) { + special_tokens.pad_id = token_id; + } else if (word == BOS_TOKEN) { + special_tokens.bos_id = token_id; + } else if (word == EOS_TOKEN) { + special_tokens.eos_id = token_id; + } } -Encoder::Encoder(const std::string &vocab_path, int n_threads) : Encoder(read_lines(std::ifstream(vocab_path)), n_threads) {} +Encoder::Encoder(const std::string &vocab_path, int n_threads) + : Encoder(read_lines(std::ifstream(vocab_path)), n_threads) {} -Encoder::Encoder(std::vector vocab, int n_threads) : vocab_(std::move(vocab)), word_piece_vocab_(vocab_), thread_pool_(n_threads) { +Encoder::Encoder(std::vector vocab, int n_threads) + : vocab_(std::move(vocab)), word_piece_vocab_(vocab_), thread_pool_(n_threads) { build_word_maps(); } @@ -189,7 +188,7 @@ Status Encoder::decode(const std::vector &ids, bool Encoder::is_word_prefix(const std::vector &text, size_t index) { return index == 0 || vkcom::is_spacing_char(text[index]) - || vkcom::is_spacing_char(text[index - 1]); + || vkcom::is_spacing_char(text[index - 1]); } void Encoder::build_word_maps() { @@ -245,7 +244,8 @@ std::vector Encoder::encode_parallel(const std::vector &text) con return token_ids; } -std::vector Encoder::encode_impl(const std::vector &text, size_t begin, size_t end) const { +std::vector +Encoder::encode_impl(const std::vector &text, size_t begin, size_t end) const { size_t max_len = std::min(word_piece_vocab_.max_token_len, end - begin); const int unk_token_id = vocab.special_tokens.unk_id; @@ -260,47 +260,47 @@ std::vector Encoder::encode_impl(const std::vector &text, size_t while (begin != end) { size_t word_len = 1; - if (!vkcom::is_punctuation(text[begin])) { - while (word_len < std::min(max_len, end - begin) - && !vkcom::is_spacing_char(text[begin + word_len])) { - ++word_len; - } + if (!vkcom::is_punctuation(text[begin])) { + while (word_len < std::min(max_len, end - begin) + && !vkcom::is_spacing_char(text[begin + word_len])) { + ++word_len; } + } - const uint32_t *segment_begin = text.data() + static_cast(begin); - const uint32_t *segment_end = segment_begin + static_cast(word_len); - const WordMap *word_to_id = is_word_prefix(begin) ? 
&prefix_to_id_ : &suffix_to_id_; - - vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); - while (!segment.empty()) { - auto it = word_to_id->find(segment.finish()); - if (it != word_to_id->end()) { - ++tokens_since_prefix; - token_ids.push_back(it->second); - begin += segment.size(); - break; - } else { - segment.pop_back(); - } + const uint32_t *segment_begin = text.data() + static_cast(begin); + const uint32_t *segment_end = segment_begin + static_cast(word_len); + const WordMap *word_to_id = is_word_prefix(begin) ? &prefix_to_id_ : &suffix_to_id_; + + vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); + while (!segment.empty()) { + auto it = word_to_id->find(segment.finish()); + if (it != word_to_id->end()) { + ++tokens_since_prefix; + token_ids.push_back(it->second); + begin += segment.size(); + break; + } else { + segment.pop_back(); } + } - if (segment.empty()) { - while (tokens_since_prefix > 0) { - token_ids.pop_back(); - --tokens_since_prefix; - } - token_ids.push_back(unk_token_id); - begin += word_len; - while (begin != end && !is_word_prefix(begin)) { - ++begin; - } - } else if (begin != end && is_word_prefix(begin)) { - tokens_since_prefix = 0; + if (segment.empty()) { + while (tokens_since_prefix > 0) { + token_ids.pop_back(); + --tokens_since_prefix; } - - while (begin != end && vkcom::is_space(text[begin])) { + token_ids.push_back(unk_token_id); + begin += word_len; + while (begin != end && !is_word_prefix(begin)) { ++begin; } + } else if (begin != end && is_word_prefix(begin)) { + tokens_since_prefix = 0; + } + + while (begin != end && vkcom::is_space(text[begin])) { + ++begin; + } } return token_ids; diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h index be83911..b1cd2ae 100644 --- a/youtokentome/cpp/wordpiece.h +++ b/youtokentome/cpp/wordpiece.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include "third_party/thread_pool/thread_pool.h" #include "utils.h" @@ -19,26 +19,25 @@ struct WordPieceToken { }; struct WordPieceVocabulary { - explicit WordPieceVocabulary(const std::vector& words); + explicit WordPieceVocabulary(const std::vector &words); std::vector tokens; vkcom::SpecialTokens special_tokens; size_t max_token_len = 0; -private: - void update_special_tokens(const std::string& word, int token_id); + private: + void update_special_tokens(const std::string &word, int token_id); }; class Encoder { -public: + public: explicit Encoder(const std::string &vocab_path, int n_threads); explicit Encoder(std::vector vocab, int n_threads); Status encode_as_ids(const std::string &text_path, std::vector *ids) const; - Status encode_as_subwords(const std::string &text_path, - std::vector *subwords) const; + Status encode_as_subwords(const std::string &text_path, std::vector *subwords) const; Status decode(const std::vector &ids, std::vector *subwords, @@ -48,7 +47,7 @@ class Encoder { int subword_to_id(const std::string &token) const; -private: + private: static const uint64_t kReadBatchLimit = 10 * 1024 * 1024; static bool is_word_prefix(const std::vector &text, size_t index); From aa608732f34d792d1c40196e2e028f3ce46a4d24 Mon Sep 17 00:00:00 2001 From: Gleb Koveshnikov Date: Tue, 11 Apr 2023 03:19:26 +0800 Subject: [PATCH 10/10] fix build --- .gitignore | 3 +- tests/unit_tests/wordpiece/stress_test.cpp | 120 +++++++++++---------- tests/unit_tests/wordpiece/test_stress.py | 2 +- youtokentome/cpp/utils.cpp | 9 ++ youtokentome/cpp/utils.h | 2 + youtokentome/cpp/wordpiece.cpp | 40 ++++--- 
youtokentome/cpp/wordpiece.h               |   5 +-
 youtokentome/cpp/yttm.pyx                  |   4 +-
 8 files changed, 110 insertions(+), 75 deletions(-)

diff --git a/.gitignore b/.gitignore
index 86d70e3..be4ce66 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,7 +57,8 @@ coverage.xml
 *.txt
 *.yttm
 artifacts/
-stress
+bpe_stress
+wordpiece_stress

 # Translations
 *.mo
diff --git a/tests/unit_tests/wordpiece/stress_test.cpp b/tests/unit_tests/wordpiece/stress_test.cpp
index 6253a83..a9a7fd7 100644
--- a/tests/unit_tests/wordpiece/stress_test.cpp
+++ b/tests/unit_tests/wordpiece/stress_test.cpp
@@ -1,4 +1,5 @@
 #include
+#include
 #include
 #include
 #include
@@ -7,9 +8,9 @@
 #include
 #include

-#include "../../youtokentome/cpp/wordpiece.h"
+#include "../../../youtokentome/cpp/wordpiece.h"

-namespace vkcom {
+using namespace vkcom;

 struct TestCase {
   std::string text;
@@ -31,9 +32,9 @@ void dump_test_case(const TestCase &test_case) {
     std::ofstream fout("stress.txt");
     fout << test_case.text;
   }
-  dump_vector("vocab.txt", text.vocab, '\n');
-  dump_vector("anwer_encoded.txt", text.anwer_encoded, ' ');
-  dump_vector("answer_decoded.txt", text.answer_decoded, ' ');
+  dump_vector("vocab.txt", test_case.vocab, '\n');
+  dump_vector("answer_encoded.txt", test_case.answer_encoded, ' ');
+  dump_vector("answer_decoded.txt", test_case.answer_decoded, ' ');
 }

 void check(const TestCase &test_case, const std::vector &encoded, const std::vector &decoded) {
@@ -44,7 +45,10 @@
 }

 std::string get_random_string(std::mt19937 &rnd, size_t string_length) {
-  static constexpr std::string_view kAllChars = "abcdefghijklmnopqrstuvwxyz";
+  static const std::string kAllChars = "abcdefghijklmnopqrstuvwxyz";
+  if (string_length == 0) {
+    throw std::runtime_error("string_length cannot be 0");
+  }
   std::string result;
   result.reserve(string_length);
   while (string_length > 0) {
@@ -59,39 +63,40 @@ TestCase generate_test_case(size_t text_len, size_t parts) {
   std::mt19937 rnd(17);
   std::string text;
   text.reserve(text_len + parts);
-  std::uniform_int_distribution word_len(0ul, text_len / 2);
+  std::uniform_int_distribution word_len(1ul, std::max(2 * text_len / parts, 3ul));
   std::unordered_map vocab_map;
   std::vector answer_encoded;
   std::vector answer_decoded;
-  text.reserve(parts);
+  answer_encoded.reserve(parts);
+  answer_decoded.reserve(parts);

-  for (size_t i = 0; i < parts; i++) {
+  for (size_t i = 0; i < parts && text.size() < text.capacity(); i++) {
     const size_t vocab_size = vocab_map.size();
     if (i + 1 == parts) {
-      size_t leftover = text.capacity() > text.size() ? text.capacity() - text.size() : 0;
+      size_t leftover = text.capacity() - text.size();
       std::string word = get_random_string(rnd, leftover);
       if (vocab_map[word] == 0) {
-        vocab_map[word] = static_cast(vocab_size);
+        vocab_map[word] = static_cast(vocab_size) + 1;
       }
       text.append(word);
-      answer_encoded.push_back(vocab_map[word]);
+      answer_encoded.push_back(vocab_map[word] - 1);
       answer_decoded.push_back(std::move(word));
-    } else if (text % 10 == 0) {
-      std::uniform_int_distribution rnd_word(0ul, vocab_size);
+    } else if (i > 0 && i % 10 == 0) {
+      std::uniform_int_distribution rnd_word(0ul, vocab_size - 1);
       auto it = std::next(vocab_map.begin(), rnd_word(rnd));
       text.append(it->first);
-      text.append(' ');
-      answer_encoded.push_back(it->second);
+      text.push_back(' ');
+      answer_encoded.push_back(it->second - 1);
       answer_decoded.push_back(it->first);
     } else {
       std::string word = get_random_string(rnd, word_len(rnd));
       if (vocab_map[word] == 0) {
-        vocab_map[word] = static_cast(vocab_size);
+        vocab_map[word] = static_cast(vocab_size) + 1;
       }
       text.append(word);
-      text.append(' ');
-      answer_encoded.push_back(vocab_map[word]);
+      text.push_back(' ');
+      answer_encoded.push_back(vocab_map[word] - 1);
       answer_decoded.push_back(std::move(word));
     }
   }
@@ -99,7 +104,7 @@
   std::vector vocab;
   vocab.resize(vocab_map.size());
   for (auto it = vocab_map.begin(); it != vocab_map.end(); it++) {
-    vocab[it->second] = it->first;
+    vocab[it->second - 1] = it->first;
   }
   return TestCase{std::move(text), std::move(vocab), std::move(answer_encoded), std::move(answer_decoded)};
 }
@@ -114,28 +119,30 @@ void test_stress(size_t text_len_from,
   for (size_t parts = std::min(text_len, parts_from);
        parts <= std::min(text_len, parts_to);
        parts++) {
-    for (int i = 0; i < 3; i++) {
-      TestCase test_case = generate_test_case(text_len, parts, positive);
-      std::cout << "running stress, text_len " << test_case.text.size() << ", vocab_size "
-                << test_case.vocab.size() << std::endl;
-
-      Status status;
-      std::vector encoded;
-      wordpiece::Encoder encoder(test_case.vocab, n_threads);
-      status = encoder.encode_as_ids(test_case.text, &encoded);
-      if (!status.ok()) {
-        dump_test_case(test_case);
-        throw std::runtime_error("encode_as_ids failed, test_case dumped");
-      }
-      std::vector decoded;
-      status = encoder.encode_as_subwords(test_case.text, test_case.vocab, &decoded);
-      if (!status.ok()) {
-        dump_test_case(test_case);
-        throw std::runtime_error("encode_as_subwords failed, test_case dumped");
-      }
+    const std::string text_filename("stress.txt");
+    TestCase test_case = generate_test_case(text_len, parts);
+    std::cout << "running stress, text_len " << test_case.text.size() << ' ' << text_len
+              << ", vocab_size " << test_case.vocab.size() << std::endl;
+    {
+      std::ofstream fout(text_filename);
+      fout << test_case.text;
+    }

-      check(test_case, encoded, decoded);
+    Status status;
+    std::vector encoded;
+    wordpiece::Encoder encoder(test_case.vocab, n_threads);
+    status = encoder.encode_as_ids(text_filename, &encoded);
+    if (!status.ok()) {
+      dump_test_case(test_case);
+      throw std::runtime_error("encode_as_ids failed, test_case dumped: " + status.error_message());
+    }
+    std::vector decoded;
+    status = encoder.encode_as_subwords(text_filename, &decoded);
+    if (!status.ok()) {
+      dump_test_case(test_case);
+      throw std::runtime_error("encode_as_subwords failed, test_case dumped: " + status.error_message());
     }
+    check(test_case, encoded, decoded);
   }
 }
@@ -146,26 +153,30 @@ void run_small(int n_threads) {
 }
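+// run_large drives inputs of 1e5 and 1e7 characters against vocabularies of
+// 30000 words; run_small (above) is the quick variant of the same checks.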
 void run_large(int n_threads) {
-  test_stress(100'000,
-              1'000'000,
-              400'000,
-              kWordPieceVocabSize,
-              kWordPieceVocabSize,
-              n_threads);
-  test_stress(10'000'000,
-              10'000'000,
-              200'000,
-              kWordPieceVocabSize,
-              kWordPieceVocabSize,
+  test_stress(100000,
+              1000000,
+              400000,
+              30000,
+              30000,
+              n_threads);
+  test_stress(10000000,
+              10000000,
+              200000,
+              30000,
+              30000,
               n_threads);
 }

 int main(int argc, char **argv) {
-  if (argc == 2 && argv[1] == "small") {
+  if (argc != 2) {
+    std::cerr << "Usage: " << argv[0] << " small|large|parallel" << std::endl;
+    return 1;
+  }
+  std::string mode = argv[1];
+  if (mode == "small") {
     run_small(1);
-  } else if (argc == 2 && argv[1] == "large") {
+  } else if (mode == "large") {
     run_large(1);
-  } else if (argc == 2 && argv[1] == "parallel") {
+  } else if (mode == "parallel") {
     run_small(0);
     run_large(0);
   } else {
@@ -173,4 +184,3 @@
   }
 }
-} // namespace vkcom
diff --git a/tests/unit_tests/wordpiece/test_stress.py b/tests/unit_tests/wordpiece/test_stress.py
index c28e2be..520dbd8 100644
--- a/tests/unit_tests/wordpiece/test_stress.py
+++ b/tests/unit_tests/wordpiece/test_stress.py
@@ -9,7 +9,7 @@ def compile_test():
     if tests_compiled:
         return
     build_files = ["wordpiece.cpp", "utils.cpp", "utf8.cpp"]
-    files = ["../../youtokentome/cpp/" + file_name for file_name in build_files]
+    files = ["../../../youtokentome/cpp/" + file_name for file_name in build_files]
     files.append("stress_test.cpp")

     print("compiling wordpiece stress test ...")
diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp
index 2bf42df..45502d0 100644
--- a/youtokentome/cpp/utils.cpp
+++ b/youtokentome/cpp/utils.cpp
@@ -43,6 +43,15 @@ uint64_t SpecialTokens::n_special_tokens() const {
 SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id)
     : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {}

+std::vector read_all_lines(std::istream& stream) {
+  std::vector sentences;
+  std::string s;
+  while (std::getline(stream, s)) {
+    sentences.push_back(std::move(s));
+  }
+  return sentences;
+}
+
 std::vector read_lines(std::istream& stream, uint64_t batch_limit, uint64_t *processed) {
   std::vector sentences;
   std::string s;
diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h
index 1b38d5d..687dbf7 100644
--- a/youtokentome/cpp/utils.h
+++ b/youtokentome/cpp/utils.h
@@ -45,6 +45,8 @@ struct SpecialTokens {
   uint64_t n_special_tokens() const;
 };

+std::vector read_all_lines(std::istream& stream);
+
 std::vector read_lines(std::istream& stream, uint64_t batch_limit, uint64_t *processed);

 Status fast_read_file_utf8(const std::string &file_name, std::string *file_content);
diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp
index e8f340d..da21313 100644
--- a/youtokentome/cpp/wordpiece.cpp
+++ b/youtokentome/cpp/wordpiece.cpp
@@ -7,6 +7,8 @@

 #include "utf8.h"

+namespace vkcom::wordpiece {
+
 namespace {

 const std::string UNK_TOKEN = "[UNK]";
@@ -67,9 +69,12 @@ std::vector parse_text(const char *text, size_t size, vkcom::ThreadPoo
   return text_utf8;
 }

-} // namespace
+std::vector read_lines_helper(const std::string &filename) {
+  std::ifstream fin(filename);
+  return read_all_lines(fin);
+}

-namespace vkcom::wordpiece {
+} // namespace

 WordPieceToken::WordPieceToken(const std::string &encoded_word)
     : is_prefix(true), is_special(false), is_malformed(false), word(vkcom::decode_utf8(encoded_word)) {
@@ -124,7 +129,7 @@ }

 Encoder::Encoder(const std::string &vocab_path, int n_threads)
-    : Encoder(read_lines(std::ifstream(vocab_path)), n_threads) {}
+    : Encoder(read_lines_helper(vocab_path), n_threads) {}

 Encoder::Encoder(std::vector vocab, int n_threads)
     : vocab_(std::move(vocab)), word_piece_vocab_(vocab_), thread_pool_(n_threads) {
@@ -138,7 +143,7 @@ Status Encoder::encode_as_ids(const std::string &text_path, std::vector *id
   if (!status.ok()) {
     return status;
   }
-  const std::vector text = parse_text(text_str, text_str.size(), thread_pool_);
+  const std::vector text = parse_text(text_str.data(), text_str.size(), thread_pool_);
   *ids = encode_parallel(text);
   return Status();
 } catch (const std::exception &ex) {
@@ -156,10 +161,10 @@ Status Encoder::encode_as_subwords(const std::string &text_path,
   if (!status.ok()) {
     return status;
   }
-  const std::vector text = parse_text(text_str, text_str.size(), thread_pool_);
+  const std::vector text = parse_text(text_str.data(), text_str.size(), thread_pool_);
   std::vector ids = encode_parallel(text);
   for (int id : ids) {
-    subwords->push_back(vocab[id]);
+    subwords->push_back(vocab_[id]);
   }
   return Status();
 } catch (const std::exception &ex) {
@@ -192,8 +197,8 @@ bool Encoder::is_word_prefix(const std::vector &text, size_t index) {
 }

 void Encoder::build_word_maps() {
-  for (size_t i = 0; i < vocab.tokens.size(); i++) {
-    const auto &token = vocab.tokens[i];
+  for (size_t i = 0; i < word_piece_vocab_.tokens.size(); i++) {
+    const auto &token = word_piece_vocab_.tokens[i];
     if (token.is_special || token.is_malformed) {
       continue;
     }
@@ -219,7 +224,7 @@ std::vector Encoder::encode_parallel(const std::vector &text) con
     while (work_end < text.size() && !vkcom::is_space(text[work_end])) {
       ++work_end;
     }
-    thread_pool_.submit([thread_id, work_begin, work_end, &per_thread_token_ids, &worker, &text] {
+    thread_pool_.submit([this, thread_id, work_begin, work_end, &per_thread_token_ids, &text] {
       per_thread_token_ids[thread_id] = encode_impl(text, work_begin, work_end);
     });
     work_begin = work_end;
@@ -247,7 +252,16 @@ std::vector
 Encoder::encode_impl(const std::vector &text, size_t begin, size_t end) const {
   size_t max_len = std::min(word_piece_vocab_.max_token_len, end - begin);
-  const int unk_token_id = vocab.special_tokens.unk_id;
+  if (begin == end) {
+    return {};
+  }
+  if (word_piece_vocab_.tokens.empty()) {
+    throw std::runtime_error("encode_impl: vocabulary is empty");
+  }
+  if (max_len == 0) {
+    throw std::runtime_error("encode_impl: maximum vocabulary token length is zero");
+  }
+  const int unk_token_id = word_piece_vocab_.special_tokens.unk_id;

   std::vector token_ids;
   token_ids.reserve((end - begin) / max_len + 1);
@@ -269,7 +283,7 @@
     const uint32_t *segment_begin = text.data() + static_cast(begin);
     const uint32_t *segment_end = segment_begin + static_cast(word_len);
-    const WordMap *word_to_id = is_word_prefix(begin) ? &prefix_to_id_ : &suffix_to_id_;
+    const WordMap *word_to_id = is_word_prefix(text, begin) ? 
&prefix_to_id_ : &suffix_to_id_; vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); while (!segment.empty()) { @@ -291,10 +305,10 @@ Encoder::encode_impl(const std::vector &text, size_t begin, size_t end } token_ids.push_back(unk_token_id); begin += word_len; - while (begin != end && !is_word_prefix(begin)) { + while (begin != end && !is_word_prefix(text, begin)) { ++begin; } - } else if (begin != end && is_word_prefix(begin)) { + } else if (begin != end && is_word_prefix(text, begin)) { tokens_since_prefix = 0; } diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h index b1cd2ae..8ff7e34 100644 --- a/youtokentome/cpp/wordpiece.h +++ b/youtokentome/cpp/wordpiece.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -48,13 +49,11 @@ class Encoder { int subword_to_id(const std::string &token) const; private: - static const uint64_t kReadBatchLimit = 10 * 1024 * 1024; - static bool is_word_prefix(const std::vector &text, size_t index); void build_word_maps(); - std::vector encode_parallel(const std::vector &text); + std::vector encode_parallel(const std::vector &text) const; std::vector encode_impl(const std::vector &text, size_t begin, size_t end) const; std::vector vocab_; diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index e149448..f0a4f07 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -183,13 +183,13 @@ cdef extern from "wordpiece.h" namespace "vkcom::wordpiece": Status encode_as_subwords(const string &text_path, vector[string] *subwords) const - Status decode(const vector[int] &ids, const string &vocab_path, vector[string] *subwords, const unordered_set[int] *ignore_ids) const + Status decode(const vector[int] &ids, vector[string] *subwords, const unordered_set[int] *ignore_ids) const Status id_to_subword(int id, string *subword) const int subword_to_id(const string &subword) const cdef class WordPiece: - cdef Encoder *encoder + cdef Encoder *encoder def __dealloc__(self): del self.encoder
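
Note: a minimal usage sketch for the `WordPiece` wrapper wired up above. The import path assumes the package exports the class as `youtokentome.WordPiece`; `vocab.txt` (one token per line) and `text.txt` are placeholder file names, not files shipped with the repo. The wrapper takes file paths rather than strings: the C++ encoder reads the file itself and tokenizes it on its thread pool.

```python
import youtokentome as yttm

# n_threads=0 mirrors the wrapper's default and lets the thread pool decide
# how many workers to spawn.
wp = yttm.WordPiece("vocab.txt", n_threads=0)

# Encode the contents of a text file either to token ids or to subwords.
ids = wp.encode("text.txt", output_type="id")
subwords = wp.encode("text.txt", output_type="subword")

# Map ids back to subword strings, skipping any ids listed in ignore_ids.
tokens = wp.decode(ids, ignore_ids={0})
```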