From 96a1c51d7429cb53342069215a0c06f05f070db0 Mon Sep 17 00:00:00 2001 From: Kay Faraday Date: Wed, 11 Jan 2023 06:34:35 +0000 Subject: [PATCH] wip gpt2 --- generators/gpt2.py | 32 ++++++++++++++++++++++++++++++++ requirements/gpt2.txt | 8 ++------ 2 files changed, 34 insertions(+), 6 deletions(-) create mode 100644 generators/gpt2.py diff --git a/generators/gpt2.py b/generators/gpt2.py new file mode 100644 index 0000000..83f3bb7 --- /dev/null +++ b/generators/gpt2.py @@ -0,0 +1,32 @@ +import pytorch_lightning.utilities +# hack until https://github.com/minimaxir/aitextgen/issues/200 is fixed +pytorch_lightning.utilities._TPU_AVAILABLE = False + +from aitextgen.TokenDataset import TokenDataset +from aitextgen.tokenizers import train_tokenizer +from aitextgen.utils import GPT2ConfigCPU +from aitextgen import aitextgen + +# The name of the downloaded Shakespeare text for training +file_name = "littlethief.txt" + +# Train a custom BPE Tokenizer on the downloaded text +# This will save one file: `aitextgen.tokenizer.json`, which contains the +# information needed to rebuild the tokenizer. +train_tokenizer(file_name) +tokenizer_file = "aitextgen.tokenizer.json" + +# GPT2ConfigCPU is a mini variant of GPT-2 optimized for CPU-training +# e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2. +config = GPT2ConfigCPU() + +# Instantiate aitextgen using the created tokenizer and config +ai = aitextgen(tokenizer_file=tokenizer_file, config=config) + +# You can build datasets for training by creating TokenDatasets, +# which automatically processes the dataset with the appropriate size. +data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64) + +# Train the model! It will save pytorch_model.bin periodically and after completion to the `trained_model` folder. +# On a 2020 8-core iMac, this took ~25 minutes to run. +ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000) \ No newline at end of file diff --git a/requirements/gpt2.txt b/requirements/gpt2.txt index 552f331..f38327d 100644 --- a/requirements/gpt2.txt +++ b/requirements/gpt2.txt @@ -1,6 +1,2 @@ -fire ~= 0.1.3 -# this package claims perpetual backwards compatibility with the "re" module so we allow any version -regex -requests ~= 2.0 -tqdm ~= 4.0 -toposort ~= 1.0 +pytorch-lightning==1.7.7 +aitextgen