author    artnitolog <53991623+artnitolog@users.noreply.github.com>  2022-06-22 20:35:12 +0300
committer artnitolog <53991623+artnitolog@users.noreply.github.com>  2022-06-22 20:35:12 +0300
commit    56db7983b1a02114fce35d8850599c93f4b3079d
tree      ac004925eee1fccbc4ee1d2bc69eed0f6770c5b6
parent    72f66b3fd3bfb51bf75eac4a7562e9533ced843b

add src
-rw-r--r--  LICENSE | 202
-rw-r--r--  docker/Dockerfile | 73
-rw-r--r--  docker/build.sh | 1
-rw-r--r--  docker/pull.sh | 2
-rw-r--r--  docker/run.sh | 10
-rw-r--r--  download/download.sh | 9
-rw-r--r--  examples/example_cond_input.json | 3
-rw-r--r--  examples/generate_conditional_greedy.sh | 55
-rw-r--r--  examples/generate_conditional_sampling.sh | 57
-rw-r--r--  examples/generate_interactive.sh | 51
-rw-r--r--  examples/generate_unconditional.sh | 54
-rw-r--r--  megatron_lm/LICENSE | 264
-rw-r--r--  megatron_lm/MANIFEST.in | 2
-rw-r--r--  megatron_lm/README.md | 568
-rw-r--r--  megatron_lm/changes.md | 1
-rw-r--r--  megatron_lm/curriculum_learning/README.md | 1
-rw-r--r--  megatron_lm/curriculum_learning/ds_pretrain_gpt2.sh | 183
-rw-r--r--  megatron_lm/curriculum_learning/ds_train.sh | 37
-rw-r--r--  megatron_lm/curriculum_learning/ds_zero_stage_2_config_baseline.json | 31
-rw-r--r--  megatron_lm/curriculum_learning/ds_zero_stage_2_config_curriculum_fixed_linear.json | 42
-rwxr-xr-x  megatron_lm/examples/ds_pretrain_gpt2-zero2.sh | 164
-rwxr-xr-x  megatron_lm/examples/ds_pretrain_gpt2-zero3.sh | 164
-rwxr-xr-x  megatron_lm/examples/ds_pretrain_gpt2.sh | 133
-rwxr-xr-x  megatron_lm/examples/ds_zero_stage_2_config.json | 32
-rwxr-xr-x  megatron_lm/examples/ds_zero_stage_3_config.json | 24
-rwxr-xr-x  megatron_lm/examples/ds_zero_stage_3_config_release.json | 29
-rwxr-xr-x  megatron_lm/examples/ds_zero_stage_infinity_config.json | 47
-rwxr-xr-x  megatron_lm/examples/evaluate_zeroshot_gpt2.sh | 38
-rwxr-xr-x  megatron_lm/examples/finetune_mnli_distributed.sh | 44
-rwxr-xr-x  megatron_lm/examples/finetune_race_distributed.sh | 47
-rwxr-xr-x  megatron_lm/examples/generate_text.sh | 25
-rwxr-xr-x  megatron_lm/examples/merge_mp_bert.sh | 18
-rwxr-xr-x  megatron_lm/examples/pretrain_bert.sh | 35
-rwxr-xr-x  megatron_lm/examples/pretrain_bert_distributed.sh | 44
-rwxr-xr-x  megatron_lm/examples/pretrain_gpt2.sh | 43
-rwxr-xr-x  megatron_lm/examples/pretrain_gpt2_distributed.sh | 52
-rw-r--r--  megatron_lm/images/Makefile | 12
-rw-r--r--  megatron_lm/images/cases.png | bin 0 -> 11807 bytes
-rw-r--r--  megatron_lm/images/scaling-dp.png | bin 0 -> 13419 bytes
-rw-r--r--  megatron_lm/images/scaling-mp.png | bin 0 -> 22813 bytes
-rw-r--r--  megatron_lm/images/tables.tex | 40
-rw-r--r--  megatron_lm/megatron/__init__.py | 41
-rw-r--r--  megatron_lm/megatron/arguments.py | 587
-rw-r--r--  megatron_lm/megatron/checkpointing.py | 464
-rw-r--r--  megatron_lm/megatron/data/Makefile | 9
-rw-r--r--  megatron_lm/megatron/data/__init__.py | 1
-rw-r--r--  megatron_lm/megatron/data/bert_dataset.py | 232
-rw-r--r--  megatron_lm/megatron/data/dataset_utils.py | 503
-rw-r--r--  megatron_lm/megatron/data/gpt2_dataset.py | 317
-rw-r--r--  megatron_lm/megatron/data/helpers.cpp | 643
-rw-r--r--  megatron_lm/megatron/data/ict_dataset.py | 140
-rw-r--r--  megatron_lm/megatron/data/indexed_dataset.py | 570
-rw-r--r--  megatron_lm/megatron/data/realm_dataset_utils.py | 201
-rw-r--r--  megatron_lm/megatron/data/realm_index.py | 216
-rw-r--r--  megatron_lm/megatron/data/samplers.py | 148
-rw-r--r--  megatron_lm/megatron/data/test/test_indexed_dataset.py | 125
-rwxr-xr-x  megatron_lm/megatron/data/test/test_preprocess_data.sh | 10
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/__init__.py | 141
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/configure_data.py | 252
-rwxr-xr-x  megatron_lm/megatron/deprecated_data_utils/corpora.py | 61
-rwxr-xr-x  megatron_lm/megatron/deprecated_data_utils/datasets.py | 883
-rwxr-xr-x  megatron_lm/megatron/deprecated_data_utils/file_utils.py | 253
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/lazy_loader.py | 202
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/samplers.py | 143
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/scripts/presplit_sentences_json.py | 27
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/scripts/split_gpt2_json.py | 141
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/scripts/split_json.py | 126
-rwxr-xr-x  megatron_lm/megatron/deprecated_data_utils/tf_dl.py | 129
-rwxr-xr-x  megatron_lm/megatron/deprecated_data_utils/tokenization.py | 922
-rw-r--r--  megatron_lm/megatron/deprecated_data_utils/tokenization_gpt2.py | 319
-rwxr-xr-x  megatron_lm/megatron/deprecated_data_utils/wordpiece.py | 391
-rw-r--r--  megatron_lm/megatron/fp16/__init__.py | 30
-rwxr-xr-x  megatron_lm/megatron/fp16/fp16.py | 651
-rw-r--r--  megatron_lm/megatron/fp16/fp16util.py | 216
-rwxr-xr-x  megatron_lm/megatron/fp16/loss_scaler.py | 256
-rw-r--r--  megatron_lm/megatron/fused_kernels/__init__.py | 100
-rw-r--r--  megatron_lm/megatron/fused_kernels/scaled_masked_softmax.cpp | 74
-rw-r--r--  megatron_lm/megatron/fused_kernels/scaled_masked_softmax.h | 452
-rw-r--r--  megatron_lm/megatron/fused_kernels/scaled_masked_softmax_cuda.cu | 102
-rw-r--r--  megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp | 69
-rw-r--r--  megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h | 439
-rw-r--r--  megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu | 89
-rw-r--r--  megatron_lm/megatron/global_vars.py | 233
-rw-r--r--  megatron_lm/megatron/indexer.py | 91
-rw-r--r--  megatron_lm/megatron/initialize.py | 208
-rw-r--r--  megatron_lm/megatron/learning_rates.py | 154
-rw-r--r--  megatron_lm/megatron/memory.py | 145
-rwxr-xr-x  megatron_lm/megatron/model/__init__.py | 21
-rw-r--r--  megatron_lm/megatron/model/bert_model.py | 196
-rw-r--r--  megatron_lm/megatron/model/classification.py | 98
-rwxr-xr-x  megatron_lm/megatron/model/distributed.py | 112
-rw-r--r--  megatron_lm/megatron/model/fused_bias_gelu.py | 60
-rw-r--r--  megatron_lm/megatron/model/fused_softmax.py | 127
-rw-r--r--  megatron_lm/megatron/model/gpt2_model.py | 125
-rw-r--r--  megatron_lm/megatron/model/language_model.py | 503
-rw-r--r--  megatron_lm/megatron/model/multiple_choice.py | 110
-rw-r--r--  megatron_lm/megatron/model/realm_model.py | 204
-rw-r--r--  megatron_lm/megatron/model/transformer.py | 1079
-rw-r--r--  megatron_lm/megatron/model/utils.py | 83
-rw-r--r--  megatron_lm/megatron/module.py | 31
-rw-r--r--  megatron_lm/megatron/mpu/__init__.py | 53
-rw-r--r--  megatron_lm/megatron/mpu/cross_entropy.py | 110
-rw-r--r--  megatron_lm/megatron/mpu/data.py | 116
-rw-r--r--  megatron_lm/megatron/mpu/grads.py | 127
-rw-r--r--  megatron_lm/megatron/mpu/initialize.py | 162
-rw-r--r--  megatron_lm/megatron/mpu/layers.py | 369
-rw-r--r--  megatron_lm/megatron/mpu/mappings.py | 157
-rw-r--r--  megatron_lm/megatron/mpu/random.py | 319
-rw-r--r--  megatron_lm/megatron/mpu/tests/__init__.py | 0
-rw-r--r--  megatron_lm/megatron/mpu/tests/commons.py | 83
-rw-r--r--  megatron_lm/megatron/mpu/tests/test_cross_entropy.py | 108
-rw-r--r--  megatron_lm/megatron/mpu/tests/test_data.py | 88
-rw-r--r--  megatron_lm/megatron/mpu/tests/test_initialize.py | 95
-rw-r--r--  megatron_lm/megatron/mpu/tests/test_layers.py | 530
-rw-r--r--  megatron_lm/megatron/mpu/tests/test_random.py | 204
-rw-r--r--  megatron_lm/megatron/mpu/utils.py | 70
-rw-r--r--  megatron_lm/megatron/package_info.py | 30
-rw-r--r--  megatron_lm/megatron/text_generation_utils.py | 397
-rw-r--r--  megatron_lm/megatron/tokenizer/__init__.py | 17
-rw-r--r--  megatron_lm/megatron/tokenizer/bert_tokenization.py | 402
-rw-r--r--  megatron_lm/megatron/tokenizer/gpt2_tokenization.py | 321
-rw-r--r--  megatron_lm/megatron/tokenizer/sp_tokenization.py | 78
-rw-r--r--  megatron_lm/megatron/tokenizer/tokenizer.py | 280
-rw-r--r--  megatron_lm/megatron/training.py | 685
-rw-r--r--  megatron_lm/megatron/utils.py | 196
-rw-r--r--  megatron_lm/pretrain_bert.py | 123
-rw-r--r--  megatron_lm/pretrain_gpt2.py | 139
-rw-r--r--  megatron_lm/pretrain_ict.py | 138
-rw-r--r--  megatron_lm/requirements.txt | 5
-rw-r--r--  megatron_lm/setup.py | 91
-rw-r--r--  megatron_lm/tasks/data_utils.py | 118
-rw-r--r--  megatron_lm/tasks/ensemble_classifier.py | 149
-rw-r--r--  megatron_lm/tasks/eval_utils.py | 127
-rw-r--r--  megatron_lm/tasks/finetune_utils.py | 259
-rw-r--r--  megatron_lm/tasks/glue/data.py | 69
-rw-r--r--  megatron_lm/tasks/glue/finetune.py | 90
-rw-r--r--  megatron_lm/tasks/glue/mnli.py | 84
-rw-r--r--  megatron_lm/tasks/glue/qqp.py | 101
-rw-r--r--  megatron_lm/tasks/main.py | 69
-rw-r--r--  megatron_lm/tasks/race/data.py | 131
-rw-r--r--  megatron_lm/tasks/race/finetune.py | 63
-rw-r--r--  megatron_lm/tasks/zeroshot_gpt2/datasets.py | 161
-rw-r--r--  megatron_lm/tasks/zeroshot_gpt2/detokenizer.py | 80
-rw-r--r--  megatron_lm/tasks/zeroshot_gpt2/evaluate.py | 195
-rw-r--r--  megatron_lm/tools/create_doc_index.py | 30
-rw-r--r--  megatron_lm/tools/generate_samples_gpt2.py | 104
-rw-r--r--  megatron_lm/tools/linter.py | 36
-rw-r--r--  megatron_lm/tools/merge_mp_partitions.py | 290
-rw-r--r--  megatron_lm/tools/openwebtext/README.md | 46
-rw-r--r--  megatron_lm/tools/openwebtext/blacklist_urls.py | 312
-rw-r--r--  megatron_lm/tools/openwebtext/cleanup_dataset.py | 115
-rw-r--r--  megatron_lm/tools/openwebtext/find_duplicates.py | 100
-rw-r--r--  megatron_lm/tools/openwebtext/group_duplicates_url.py | 90
-rw-r--r--  megatron_lm/tools/openwebtext/merge_jsons.py | 55
-rw-r--r--  megatron_lm/tools/openwebtext/remove_group_duplicates.py | 69
-rw-r--r--  megatron_lm/tools/preprocess_data.py | 200
156 files changed, 25528 insertions, 0 deletions
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..dde066c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+Copyright 2022 YANDEX LLC
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+Copyright 2022 YANDEX LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. \ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..da657a8
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,73 @@
+FROM nvcr.io/nvidia/pytorch:20.11-py3
+
+# Though already installed, pybind11 could not be found during the build. Reinstalling fixes the problem.
+RUN python3 -mpip install pybind11==2.6.2 -U
+
+##############################################################################
+# Updating NCCL to 2.10.3-1
+##############################################################################
+RUN wget --quiet --tries=5 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/libnccl2_2.10.3-1+cuda11.0_amd64.deb
+RUN apt install --yes --fix-broken --no-install-recommends ./libnccl2_2.10.3-1+cuda11.0_amd64.deb
+RUN rm -f ./libnccl2_2.10.3-1+cuda11.0_amd64.deb
+RUN wget --quiet --tries=5 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/libnccl-dev_2.10.3-1+cuda11.0_amd64.deb
+RUN apt install --yes --fix-broken --no-install-recommends ./libnccl-dev_2.10.3-1+cuda11.0_amd64.deb
+RUN rm -f ./libnccl-dev_2.10.3-1+cuda11.0_amd64.deb
+
+##############################################################################
+# Installation/Basic Utilities
+##############################################################################
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ software-properties-common build-essential autotools-dev \
+ nfs-common pdsh \
+ cmake g++ gcc \
+ curl wget vim tmux emacs less unzip \
+ htop iftop iotop ca-certificates openssh-client openssh-server \
+ rsync iputils-ping net-tools sudo \
+ llvm-9-dev
+
+##############################################################################
+# Installation Latest Git
+##############################################################################
+RUN add-apt-repository ppa:git-core/ppa -y && \
+ apt-get update && \
+ apt-get install -y git && \
+ git --version
+
+##############################################################################
+# Reinstall PyTorch to 1.10.0 from Oct, 2021; use local NCCL
+##############################################################################
+RUN git clone --recursive https://github.com/pytorch/pytorch.git && \
+ cd pytorch && \
+ git checkout v1.10.0 && \
+ git submodule sync && \
+ git submodule update --init --recursive && \
+ export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} && \
+ USE_SYSTEM_NCCL=1 python3 setup.py install && \
+ cd .. && rm -rf pytorch
+
+##############################################################################
+# apex a651e2c24ecf97cbf367fd3f330df36760e1c597
+##############################################################################
+RUN git clone --recursive https://github.com/NVIDIA/apex && \
+ cd apex && \
+ git reset --hard a651e2c24ecf97cbf367fd3f330df36760e1c597 && \
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . && \
+ cd .. && rm -rf apex
+
+##############################################################################
+# Other packages
+##############################################################################
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ libsndfile-dev \
+ libcupti-dev \
+ libjpeg-dev \
+ libpng-dev \
+ screen \
+ libaio-dev
+
+##############################################################################
+# Installing DeepSpeed
+##############################################################################
+RUN pip install deepspeed==0.6.5 \ No newline at end of file
diff --git a/docker/build.sh b/docker/build.sh
new file mode 100644
index 0000000..5c905f6
--- /dev/null
+++ b/docker/build.sh
@@ -0,0 +1 @@
+docker build -t yalm-cuda11-ds:1.0 --network host .
diff --git a/docker/pull.sh b/docker/pull.sh
new file mode 100644
index 0000000..2fd5d04
--- /dev/null
+++ b/docker/pull.sh
@@ -0,0 +1,2 @@
+docker pull yandex/yalm-cuda11-ds:1.0
+docker tag yandex/yalm-cuda11-ds:1.0 yalm-cuda11-ds:1.0
diff --git a/docker/run.sh b/docker/run.sh
new file mode 100644
index 0000000..1206e54
--- /dev/null
+++ b/docker/run.sh
@@ -0,0 +1,10 @@
+IMAGE_NAME=yalm-cuda11-ds:1.0
+
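+# Mount /dev/shm and $HOME, forward the SSH agent socket, and start an interactive shell with all GPUs on the host network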
+docker run \
+--mount type=bind,source=/dev/shm,target=/dev/shm \
+-v $HOME:$HOME \
+--name "yalm-cuda11-ds-${USER}" \
+-v ${SSH_AUTH_SOCK}:${SSH_AUTH_SOCK} -e SSH_AUTH_SOCK="${SSH_AUTH_SOCK}" \
+-e REAL_USER="${USER}" \
+--net host -it --rm --gpus all \
+$IMAGE_NAME /bin/bash
diff --git a/download/download.sh b/download/download.sh
new file mode 100644
index 0000000..68585b3
--- /dev/null
+++ b/download/download.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+mkdir -p yalm100b_checkpoint/vocab yalm100b_checkpoint/weights
+
+cd yalm100b_checkpoint/vocab
+curl --remote-name-all https://yalm-100b.s3.mds.yandex.net/vocab/voc_100b.sp
+
+cd ../weights
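+# curl URL globbing ({...} lists and [NN-NN] ranges) downloads all layer checkpoint shards with a single command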
+curl --remote-name-all https://yalm-100b.s3.mds.yandex.net/weights/layer_{00,01,[03-82],84}-model_00-model_states.pt
diff --git a/examples/example_cond_input.json b/examples/example_cond_input.json
new file mode 100644
index 0000000..0db0bed
--- /dev/null
+++ b/examples/example_cond_input.json
@@ -0,0 +1,3 @@
+{"prefix": "One, two, three, four, five, six, seven, "}
+{"prefix": "Everyone wants to know:"}
+{"prefix": "Once upon a time"}
diff --git a/examples/generate_conditional_greedy.sh b/examples/generate_conditional_greedy.sh
new file mode 100644
index 0000000..f35bbaa
--- /dev/null
+++ b/examples/generate_conditional_greedy.sh
@@ -0,0 +1,55 @@
+# Set visible devices
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# Set MP_SIZE to the number of devices
+MP_SIZE=8
+
+# Provide path to vocab file and model
+VOCAB_PATH="yalm100b_checkpoint/vocab/voc_100b.sp"
+MODEL_PATH="yalm100b_checkpoint/weights"
+LOAD_ARGS="\
+ --vocab-file ${VOCAB_PATH} \
+ --load ${MODEL_PATH}"
+
+# Set generation parameters
+COND_ARGS="\
+ --sample-input-file examples/example_cond_input.json \
+ --sample-output-file cond_output.json \
+ --sample-context-field prefix \
+ --sample-generated-field suffix"
+GEN_ARGS="
+ --greedy \
+ --seq-length 256 \
+ --out-seq-length 128"
+
+HPARAM_ARGS="\
+ --pos-encoding-type rotary \
+ --num-layers 80 \
+ --embedding-size 2048 \
+ --hidden-size 10240 \
+ --intermediate-size 27308 \
+ --activation-type geglu \
+ --num-attention-heads 128 \
+ --max-position-embeddings 1024 \
+ --tokenizer-type SentencePiece \
+ --fp16"
+
+DISTRIBUTED_ARGS="--nproc_per_node $MP_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port=1234"
+
+COMMON_ARGS="\
+ --num-samples 0 \
+ --load-release-checkpoint \
+ --batch-size 1 \
+ --model-parallel-size $MP_SIZE \
+ --make-vocab-size-divisible-by 1"
+
+torchrun $DISTRIBUTED_ARGS megatron_lm/tools/generate_samples_gpt2.py \
+ $LOAD_ARGS \
+ $HPARAM_ARGS \
+ $COMMON_ARGS \
+ $GEN_ARGS \
+ $COND_ARGS
diff --git a/examples/generate_conditional_sampling.sh b/examples/generate_conditional_sampling.sh
new file mode 100644
index 0000000..abff1d6
--- /dev/null
+++ b/examples/generate_conditional_sampling.sh
@@ -0,0 +1,57 @@
+# Set visible devices
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# Set MP_SIZE to the number of devices
+MP_SIZE=8
+
+# Provide path to vocab file and model
+VOCAB_PATH="yalm100b_checkpoint/vocab/voc_100b.sp"
+MODEL_PATH="yalm100b_checkpoint/weights"
+LOAD_ARGS="\
+ --vocab-file ${VOCAB_PATH} \
+ --load ${MODEL_PATH}"
+
+# Set generation parameters
+COND_ARGS="\
+ --sample-input-file examples/example_cond_input.json \
+ --sample-output-file cond_output.json \
+ --sample-context-field prefix \
+ --sample-generated-field suffix"
+GEN_ARGS="
+ --temperature 1.0 \
+ --top_p 0.9 \
+ --seed 1234 \
+ --seq-length 256 \
+ --out-seq-length 128"
+
+HPARAM_ARGS="\
+ --pos-encoding-type rotary \
+ --num-layers 80 \
+ --embedding-size 2048 \
+ --hidden-size 10240 \
+ --intermediate-size 27308 \
+ --activation-type geglu \
+ --num-attention-heads 128 \
+ --max-position-embeddings 1024 \
+ --tokenizer-type SentencePiece \
+ --fp16"
+
+DISTRIBUTED_ARGS="--nproc_per_node $MP_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port=1234"
+
+COMMON_ARGS="\
+ --num-samples 0 \
+ --load-release-checkpoint \
+ --batch-size 1 \
+ --model-parallel-size $MP_SIZE \
+ --make-vocab-size-divisible-by 1"
+
+torchrun $DISTRIBUTED_ARGS megatron_lm/tools/generate_samples_gpt2.py \
+ $LOAD_ARGS \
+ $HPARAM_ARGS \
+ $COMMON_ARGS \
+ $GEN_ARGS \
+ $COND_ARGS \ No newline at end of file
diff --git a/examples/generate_interactive.sh b/examples/generate_interactive.sh
new file mode 100644
index 0000000..e1b6b33
--- /dev/null
+++ b/examples/generate_interactive.sh
@@ -0,0 +1,51 @@
+# Set visible devices
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# Set MP_SIZE to the number of devices
+MP_SIZE=8
+
+# Provide path to vocab file and model
+VOCAB_PATH="yalm100b_checkpoint/vocab/voc_100b.sp"
+MODEL_PATH="yalm100b_checkpoint/weights"
+LOAD_ARGS="\
+ --vocab-file ${VOCAB_PATH} \
+ --load ${MODEL_PATH}"
+
+# Set generation parameters
+GEN_ARGS="
+ --temperature 1.0 \
+ --top_p 0.9 \
+ --seed 1234 \
+ --seq-length 256 \
+ --out-seq-length 128"
+
+HPARAM_ARGS="\
+ --pos-encoding-type rotary \
+ --num-layers 80 \
+ --embedding-size 2048 \
+ --hidden-size 10240 \
+ --intermediate-size 27308 \
+ --activation-type geglu \
+ --num-attention-heads 128 \
+ --max-position-embeddings 1024 \
+ --tokenizer-type SentencePiece \
+ --fp16"
+
+DISTRIBUTED_ARGS="--nproc_per_node $MP_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port=1234"
+
+COMMON_ARGS="\
+ --num-samples 0 \
+ --load-release-checkpoint \
+ --batch-size 1 \
+ --model-parallel-size $MP_SIZE \
+ --make-vocab-size-divisible-by 1"
+
+torchrun $DISTRIBUTED_ARGS megatron_lm/tools/generate_samples_gpt2.py \
+ $LOAD_ARGS \
+ $HPARAM_ARGS \
+ $COMMON_ARGS \
+ $GEN_ARGS
diff --git a/examples/generate_unconditional.sh b/examples/generate_unconditional.sh
new file mode 100644
index 0000000..715bc7b
--- /dev/null
+++ b/examples/generate_unconditional.sh
@@ -0,0 +1,54 @@
+# Set visible devices
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# Set MP_SIZE to the number of devices
+MP_SIZE=8
+
+# Provide path to vocab file and model
+VOCAB_PATH="yalm100b_checkpoint/vocab/voc_100b.sp"
+MODEL_PATH="yalm100b_checkpoint/weights"
+LOAD_ARGS="\
+ --vocab-file ${VOCAB_PATH} \
+ --load ${MODEL_PATH}"
+
+# Set generation parameters
+UNCOND_ARGS="\
+ --num-samples 2 \
+ --genfile uncond_output.json"
+GEN_ARGS="
+ --temperature 1.0 \
+ --top_p 0.9 \
+ --seed 1234 \
+ --seq-length 512 \
+ --out-seq-length 512"
+
+HPARAM_ARGS="\
+ --pos-encoding-type rotary \
+ --num-layers 80 \
+ --embedding-size 2048 \
+ --hidden-size 10240 \
+ --intermediate-size 27308 \
+ --activation-type geglu \
+ --num-attention-heads 128 \
+ --max-position-embeddings 1024 \
+ --tokenizer-type SentencePiece \
+ --fp16"
+
+DISTRIBUTED_ARGS="--nproc_per_node $MP_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port=1234"
+
+COMMON_ARGS="\
+ --load-release-checkpoint \
+ --batch-size 1 \
+ --model-parallel-size $MP_SIZE \
+ --make-vocab-size-divisible-by 1"
+
+torchrun $DISTRIBUTED_ARGS megatron_lm/tools/generate_samples_gpt2.py \
+ $LOAD_ARGS \
+ $HPARAM_ARGS \
+ $COMMON_ARGS \
+ $GEN_ARGS \
+ $UNCOND_ARGS \ No newline at end of file
diff --git a/megatron_lm/LICENSE b/megatron_lm/LICENSE
new file mode 100644
index 0000000..c4bad15
--- /dev/null
+++ b/megatron_lm/LICENSE
@@ -0,0 +1,264 @@
+The following applies to all files unless otherwise noted:
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+--
+
+This repository also contains code from Hugging Face Inc., Google Research,
+and Facebook (from their Fairseq project). Files from these
+organizations have notices at the top of each file. Below are licenses
+used in those files, as indicated.
+
+
+------------- LICENSE FOR huggingface and Google Research code --------------
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------- LICENSE FOR Facebook Fairseq code --------------
+
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/megatron_lm/MANIFEST.in b/megatron_lm/MANIFEST.in
new file mode 100644
index 0000000..f447911
--- /dev/null
+++ b/megatron_lm/MANIFEST.in
@@ -0,0 +1,2 @@
+include megatron/data/Makefile
+include megatron/data/helpers.cpp
diff --git a/megatron_lm/README.md b/megatron_lm/README.md
new file mode 100644
index 0000000..a8a58b1
--- /dev/null
+++ b/megatron_lm/README.md
@@ -0,0 +1,568 @@
+[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
+
+Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, a SQuAD 2.0 F1-score of 91.7, and a RACE accuracy of 90.9%.
+
+Our codebase is capable of efficiently training very large (several billion parameter) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs we consider the following GPT-2 model sizes. All models use a vocabulary size of 51,200 and a sequence length of 1024.
+
+![Cases](images/cases.png)
+
+The table below details the weak scaling from 1 to 8 GPUs of our model parallelism code in both a DGX-2 and a DGX-A100. Notice that we double the batch size on the DGX-A100, but the iteration time decreases compared to the DGX-2, resulting in a **2.1x** speedup for the end-to-end application.
+
+![Model Parallel Scaling](images/scaling-mp.png)
+
+The following table details how Megatron scales using data parallelism in conjunction with model parallelism in a cluster of DGX-A100s. All of these cases use 128-way data parallelism and the scaling numbers are relative to a single A100 (Case 1B with a 1076ms iteration time).
+
+![Data Parallel Scaling](images/scaling-dp.png)
+
+<a id="contents"></a>
+# Contents
+<!-- MarkdownTOC -->
+
+- [Setup](#setup)
+ - [Downloading Checkpoints](#downloading-checkpoints)
+- [Usage](#usage)
+- [Training](#training)
+ - [Data Preprocessing](#data-preprocessing)
+ - [BERT Pretraining](#bert-pretraining)
+ - [GPT-2 Pretraining](#gpt-2-pretraining)
+ - [Distributed BERT or GPT-2 Pretraining](#distributed-bert-or-gpt-2-pretraining)
+- [REALM Pipeline](#realm)
+- [Evaluation and Tasks](#evaluation-and-tasks)
+ - [GPT-2 Text Generation](#gpt-2-text-generation)
+ - [GPT-2 Evaluation](#gpt-2-evaluation)
+ - [WikiText Perplexity Evaluation](#wikitext-perplexity-evaluation)
+ - [LAMBADA Cloze Accuracy](#lambada-cloze-accuracy)
+ - [BERT Task Evaluation](#bert-task-evaluation)
+ - [RACE Evaluation](#race-evaluation)
+ - [MNLI Evaluation](#mnli-evaluation)
+- [Datasets](#datasets)
+ - [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data)
+ - [Collecting GPT-2 Webtext Data](#collecting-gpt-2-webtext-data)
+
+<!-- /MarkdownTOC -->
+
+<a id="setup"></a>
+# Setup
+We officially support only Python 3.6, PyTorch 1.5, CUDA 10, and NCCL 2.6 (and later versions of each).
+
+To use this repo please install the latest supported versions of PyTorch with GPU support and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We strongly recommend using one of [NGC's recent PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) (the latest compatible version at time of publication can be pulled with `docker pull nvcr.io/nvidia/pytorch:20.03-py3`). Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation or downstream tasks.
+
+To use megatron you can either clone the repo or install it via pip (make sure python3-dev is installed):
+<pre>
+pip install megatron-lm
+</pre>
+
+<a id="downloading-checkpoints"></a>
+## Downloading Checkpoints
+We've provided two pretrained checkpoints for use in evaluating or finetuning downstream tasks. To access these checkpoints, first please [sign up](https://ngc.nvidia.com/signup) for and [set up](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI.
+
+The checkpoints can be downloaded with:
+<pre>
+ngc registry model download-version --dest &#60;output_base_directory&#62; nvidia/&#60;model_name&#62;:&#60;version&#62;
+</pre>
+
+The available models along with `<model_name>:<version>` are below:
+* [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m): megatron\_bert\_345m:v0.0
+* [GPT-2-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m): megatron\_lm\_345m:v0.0
+
+The models require vocabulary files to run. The BERT uncased WordPiece vocab file can be extracted from Google's [pretrained BERT models](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt). The GPT-2 [vocab file](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json) and [merge table](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt) can be downloaded directly.
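+
+For reference, these vocabulary files can be fetched directly from the command line (a minimal sketch using only the URLs listed above):
+<pre>
+wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
+</pre>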
+
+Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1).
+
+<a id="usage"></a>
+# Usage
+
+After installation, there are several possible workflows. The most comprehensive is:
+1. Data preprocessing
+2. Pretraining
+3. Finetuning (Optional for zero-shot tasks)
+4. Downstream task evaluation or text generation
+
+However, steps 1 and 2 can be replaced by using one of the pretrained models mentioned above.
+
+We've provided several scripts for pretraining both BERT and GPT-2 in the [`examples`](./examples) directory, as well as scripts for both zero-shot and fine-tuned downstream tasks including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT-2 interactive text generation.
+
+<a id="training"></a>
+# Training
+<a id="data-preprocessing"></a>
+## Data Preprocessing
+We support three file formats for training, but all require preprocessing. First, place your training data in a loose json format, with one JSON object containing a text sample per line. For example:
+<pre>
+{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
+{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
+</pre>
+
+The name of the `text` field of the json can be changed by using the `--json-key` flag in [`preprocess_data.py`](./tools/preprocess_data.py). The other metadata fields are optional and are not used in training.
+
+The loose json is then processed into a binary format for training. To convert the json into mmap, cached index file, or the lazy loader format use `preprocess_data.py`. Set the `--dataset-impl` flag to `mmap`, `cached`, or `lazy`, respectively (default is `mmap`). An example script to prepare data for BERT training is:
+<pre>
+python tools/preprocess_data.py \
+ --input my-corpus.json \
+ --output-prefix my-bert \
+ --vocab bert-vocab.txt \
+ --dataset-impl mmap \
+ --tokenizer-type BertWordPieceLowerCase \
+ --split-sentences
+</pre>
+
+The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension.
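+
+For example, with the command above the later training run would simply reference the shared prefix of those two files (a sketch of the relationship, using only the names shown here):
+<pre>
+# produced by preprocess_data.py:
+#   my-bert_text_sentence.bin
+#   my-bert_text_sentence.idx
+DATA_PATH=my-bert_text_sentence   # passed later as --data-path, without the extension
+</pre>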
+
+Some minor modifications are required for GPT-2 data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:
+<pre>
+python tools/preprocess_data.py \
+ --input my-corpus.json \
+ --output-prefix my-gpt2 \
+ --vocab gpt2-vocab.json \
+ --dataset-impl mmap \
+ --tokenizer-type GPT2BPETokenizer \
+ --merge-file gpt2-merges.txt \
+ --append-eod
+</pre>
+
+Here the output files are named `my-gpt2_text_document.bin` and `my-gpt2_text_document.idx`. As before, in GPT-2 training, use the longer name without the extension as `--data-path`.
+
+Further command line arguments are described in the source file [`preprocess_data.py`](./tools/preprocess_data.py).
+
+<a id="bert-pretraining"></a>
+## BERT Pretraining
+`bash examples/pretrain_bert.sh`
+
+This script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations starting at `--lr` to a minimum set by `--min-lr` over `--lr-decay-iters` iterations. The fraction of training iterations used for warmup is set by `--warmup`. While this is single GPU training, the batch size specified by `--batch-size` is the per-GPU batch size used for data parallelism. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with `--seed`).
+
+The logging, checkpoint-saving, and evaluation intervals are specified. Checkpointing the activations facilitates the training of larger models and/or batches. Note that the `--data-path` now includes the additional `_text_sentence` suffix added in preprocessing, but does not include the file extensions.
+
+<pre>
+CHECKPOINT_PATH=checkpoints/bert_345m
+VOCAB_FILE=bert-vocab.txt
+DATA_PATH=my-bert_text_sentence
+
+BERT_ARGS="--num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --lr 0.0001 \
+ --train-iters 2000000 \
+ --min-lr 0.00001 \
+ --lr-decay-iters 990000 \
+ --warmup 0.01 \
+ --batch-size 8 \
+ --vocab-file $VOCAB_FILE \
+ --split 949,50,1 \
+ --fp16"
+
+OUTPUT_ARGS="--log-interval 10 \
+ --save-interval 500 \
+ --eval-interval 100 \
+ --eval-iters 10 \
+ --checkpoint-activations"
+
+python pretrain_bert.py \
+ $BERT_ARGS \
+ $OUTPUT_ARGS \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH
+</pre>
+
+Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py).
+
+<a id="gpt-2-pretraining"></a>
+## GPT-2 Pretraining
+`bash examples/pretrain_gpt2.sh`
+
+This script runs single GPU 345M parameter GPT-2 pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training.
+
+It follows largely the same format as the previous BERT script with a few notable differences: the tokenization scheme used is BPE (which requires a merge table and a `json` vocabulary file) instead of WordPiece, the model architecture allows for longer sequences (note that the max position embedding must be greater than or equal to the maximum sequence length), and the `--lr-decay-style` has been set to cosine decay. Note that the `--data-path` now includes the additional `_text_document` suffix added in preprocessing, but does not include the file extensions.
+
+<pre>
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+DATA_PATH=my-gpt2_text_document
+
+GPT2_ARGS="--num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --batch-size 4 \
+ --lr 0.00015 \
+ --train-iters 500000 \
+ --lr-decay-iters 320000 \
+ --lr-decay-style cosine \
+ --vocab-file $VOCAB_FILE \
+ --merge-file $MERGE_FILE \
+ --warmup .01 \
+ --fp16"
+
+OUTPUT_ARGS=&#60;same as those in <a href="#bert-pretraining">BERT pretraining</a> above&#62;
+
+python pretrain_gpt2.py \
+ $GPT2_ARGS \
+ $OUTPUT_ARGS \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+</pre>
+
+Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py).
+
+<a id="distributed-bert-or-gpt-2-pretraining"></a>
+## Distributed BERT or GPT-2 Pretraining
+`bash examples/pretrain_bert_distributed.sh`
+
+`bash examples/pretrain_gpt2_distributed.sh`
+
+These scripts use the PyTorch distributed launcher for distributed training. As such, multinode training can be achieved by properly setting environment variables and using `init_method='env://'` in the launcher. See the official PyTorch [documentation](https://pytorch.org/docs/stable/distributed.html#launch-utility) for further description of these [environment variables](https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization). By default, multinode training uses the [nccl](https://developer.nvidia.com/nccl) distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the Python flag `-m torch.distributed.launch`, detailed below, are the only additional requirements to adopt distributed training.
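+
+For reference, a two-node launch might look like the following sketch (run the command once on each node; the master address is a placeholder and `NODE_RANK` must be set to 1 on the second node):
+<pre>
+GPUS_PER_NODE=8
+NNODES=2
+NODE_RANK=0                      # 1 on the second node
+MASTER_ADDR=node0.example.com    # placeholder: address of the rank-0 node
+MASTER_PORT=6000
+
+python -m torch.distributed.launch \
+       --nproc_per_node $GPUS_PER_NODE \
+       --nnodes $NNODES \
+       --node_rank $NODE_RANK \
+       --master_addr $MASTER_ADDR \
+       --master_port $MASTER_PORT \
+       pretrain_gpt2.py &#60;same arguments as in the single-node examples below&#62;
+</pre>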
+
+The two tiers of parallelism are data and model parallelism. First, we facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of the backpropagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with backpropagation computation. To switch between these two options, use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model parallel sizes. For example, for the 8.3 billion parameter model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.
+
+Second, we developed a simple and efficient intra-layer model parallel approach. To use model parallelism, add the `--model-parallel-size` flag to specify the number of GPUs across which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--model-parallel-size` is 1, which disables model parallelism.
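+
+Concretely (a small sketch of the resulting arithmetic, using the values from the example scripts below):
+<pre>
+WORLD_SIZE=8                        # total GPUs given to the launcher
+MP_SIZE=2                           # GPUs used to split each model replica
+DP_SIZE=$((WORLD_SIZE / MP_SIZE))   # 4-way data parallelism
+</pre>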
+
+Other than these minor changes, the distributed training is identical to the training on a single GPU.
+
+Distributed BERT training:
+<pre>
+WORLD_SIZE=8
+MP_SIZE=2
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port 6000"
+
+CHECKPOINT_PATH=checkpoints/bert_345m
+VOCAB_FILE=bert-vocab.txt
+DATA_PATH=my-bert_text_sentence
+BERT_ARGS=&#60;same as those in <a href="#bert-pretraining">BERT pretraining</a> above&#62;
+OUTPUT_ARGS=&#60;same as those in <a href="#bert-pretraining">BERT pretraining</a> above&#62;
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
+ $BERT_ARGS \
+ $OUTPUT_ARGS \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --model-parallel-size $MP_SIZE \
+ --DDP-impl torch
+</pre>
+
+Distributed GPT-2 training:
+<pre>
+WORLD_SIZE=8
+MP_SIZE=2
+
+DISTRIBUTED_ARGS=&#60;same as those directly above&#62;
+
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+DATA_PATH=my-gpt2_text_document
+GPT2_ARGS=&#60;same as those in <a href="#gpt-2-pretraining">GPT-2 pretraining</a> above&#62;
+OUTPUT_ARGS=&#60;same as those in <a href="#bert-pretraining">BERT pretraining</a> above&#62;
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
+ $GPT2_ARGS \
+ $OUTPUT_ARGS \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --model-parallel-size $MP_SIZE \
+ --DDP-impl torch
+
+</pre>
+
+<a id="realm"></a>
+## REALM Pipeline
+We are working on implementing the [REALM](https://arxiv.org/pdf/2002.08909.pdf) system. The following sections (will) reflect the three stages of training it; for now, only the ICT code is included.
+Loosely, the three stages are: pretraining the retriever modules, then jointly training the language model and the retriever, and finally finetuning a question answering head on the language model with a fixed retriever.
+
+### Inverse Cloze Task (ICT) Pretraining
+1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document.
+Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body.
+Refer to the following script:
+<pre>
+python preprocess_data.py \
+ --input /path/to/corpus.json \
+ --json-keys text title \
+ --split-sentences \
+ --tokenizer-type BertWordPieceLowerCase \
+ --vocab-file /path/to/vocab.txt \
+ --output-prefix corpus_indexed \
+ --workers 5 # works well for 10 CPU cores. Scale up accordingly.
+</pre>
+
+2. Use a custom samples mapping function in place of `megatron/data/realm_dataset_utils.get_block_samples_mapping` if required. To do this, you will need to implement a new function in C++ inside of `megatron/data/helpers.cpp`. The samples mapping data structure is used to select the data that will constitute every training sample in advance of the training loop.
+ The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block.
+3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task.
+In REALM, this is an uncased BERT base model trained with the standard hyperparameters.
+4. Use `pretrain_ict.py` to train an `ICTBertModel`, which uses two BERT-based encoders to embed queries and blocks for retrieval.
+The script below trains the ICT model from REALM. It references the pretrained BERT model from step 3 via the `--bert-load` argument. The batch size used in the paper is 4096, so this script would need to be run with a data parallel world size of 32.
+<pre>
+python pretrain_ict.py \
+ --num-layers 12 \
+ --num-attention-heads 12 \
+ --hidden-size 768 \
+ --batch-size 128 \
+ --seq-length 256 \
+ --max-position-embeddings 256 \
+ --ict-head-size 128 \
+ --train-iters 100000 \
+ --checkpoint-activations \
+ --bert-load /path/to/pretrained_bert \
+ --load checkpoints \
+ --save checkpoints \
+ --data-path /path/to/indexed_dataset \
+ --titles-data-path /path/to/titles_indexed_dataset \
+ --vocab-file /path/to/vocab.txt \
+ --lr 0.0001 \
+ --num-workers 2 \
+ --lr-decay-style linear \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup .01 \
+ --save-interval 3000 \
+ --query-in-block-prob 0.1 \
+ --fp16
+
+</pre>
+
+### Building an Index of Block Embeddings
+After having trained an ICT model, you can embed an entire dataset of blocks by creating a `BlockData` structure. After that has been saved, you can load it
+and wrap it with a `FaissMIPSIndex` to do fast similarity search, which is key to the learned information retrieval pipeline. The initial index can be built with the following script, meant to be run in an interactive session. It can leverage multiple GPUs on multiple nodes to index large datasets much more quickly.
+
+<pre>
+python tools/create_doc_index.py \
+ --num-layers 12 \
+ --hidden-size 768 \
+ --ict-head-size 128 \
+ --num-attention-heads 12 \
+ --batch-size 128 \
+ --checkpoint-activations \
+ --seq-length 256 \
+ --max-position-embeddings 256 \
+ --ict-load /path/to/pretrained_ict \
+ --data-path /path/to/indexed_dataset \
+ --titles-data-path /path/to/titles_indexed_dataset \
+ --block-data-path embedded_blocks.pkl \
+ --indexer-log-interval 1000 \
+ --indexer-batch-size 128 \
+ --vocab-file /path/to/vocab.txt \
+ --num-workers 2 \
+ --fp16
+</pre>
+
+<a id="evaluation-and-tasks"></a>
+# Evaluation and Tasks
+
+We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
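+
+A minimal sketch of such a finetuning invocation (the new data path and save directory are placeholders; the remaining arguments are the same as in the pretraining scripts above):
+<pre>
+python pretrain_gpt2.py \
+       $GPT2_ARGS \
+       $OUTPUT_ARGS \
+       --finetune \
+       --load checkpoints/gpt2_345m \
+       --save checkpoints/gpt2_345m_mycorpus \
+       --data-path my-new-corpus_text_document
+</pre>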
+
+Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
+
+<pre>
+MODEL_PARALLEL_SIZE=2
+
+VOCAB_FILE=bert-vocab.txt
+CHECKPOINT_PATH=checkpoints/bert_345m
+
+WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
+ --model-type BERT \
+ --model-parallel-size $MODEL_PARALLEL_SIZE \
+ --tokenizer-type BertWordPieceLowerCase \
+ --vocab-file $VOCAB_FILE \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --load $CHECKPOINT_PATH
+
+</pre>
+
+Several downstream tasks are described for both GPT-2 and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
+
+<a id="gpt-2-text-generation"></a>
+## GPT-2 Text Generation
+`bash examples/generate_text.sh`
+
+We generate text samples using largely the GPT-2 pretraining script. Only a few changes are needed: provide the path to the pretrained checkpoint, set the length of the output samples, and choose between unconditional generation (`--num-samples` denotes how many samples to generate) and conditional generation (pass `--sample-input-file <filename>`, where each line of the file is used as a conditioning text). There are a few optional sampling parameters to experiment with, e.g. `top-k`, `top-p`, or greedy decoding (set top-k and top-p to 0).
+
+<pre>
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+GPT2_ARGS=&#60;same as those in <a href="#gpt-2-pretraining">GPT-2 pretraining</a> above&#62;
+
+MAX_OUTPUT_SEQUENCE_LENGTH=1024
+TEMPERATURE=1.0
+TOP_P=0.9
+NUMBER_OF_SAMPLES=2
+OUTPUT_FILE=samples.json
+
+python tools/generate_samples_gpt2.py \
+ $GPT2_ARGS \
+ --load $CHECKPOINT_PATH \
+ --out-seq-length $MAX_OUTPUT_SEQUENCE_LENGTH \
+ --temperature $TEMPERATURE \
+ --genfile $OUTPUT_FILE \
+ --num-samples $NUMBER_OF_SAMPLES \
+ --top_p $TOP_P \
+ --recompute
+</pre>
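+
+For conditional generation, the same script can instead be pointed at a prompt file (a sketch; `prompts.txt` is a placeholder with one conditioning text per line, and the exact handling of the output file may differ between script versions, so also see the conditional-generation example scripts shipped with the repository):
+<pre>
+python tools/generate_samples_gpt2.py \
+       $GPT2_ARGS \
+       --load $CHECKPOINT_PATH \
+       --sample-input-file prompts.txt \
+       --out-seq-length $MAX_OUTPUT_SEQUENCE_LENGTH \
+       --temperature $TEMPERATURE \
+       --top_p $TOP_P \
+       --genfile conditional_samples.json \
+       --recompute
+</pre>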
+
+<a id="gpt-2-evaluation"></a>
+## GPT-2 Evaluation
+We include example scripts for GPT-2 evaluation on WikiText perplexity and LAMBADA Cloze accuracy.
+
+<a id="wikitext-perplexity-evaluation"></a>
+### WikiText Perplexity Evaluation
+For even comparison with prior works, we evaluate perplexity on the word-level [WikiText-103 test dataset](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), and appropriately compute perplexity given the change in tokens when using our subword tokenizer.
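+
+Concretely, if the word-level test set contains T_o words but our BPE tokenizer turns it into T subword tokens, the perplexity is normalized by the original word count rather than the subword count (a sketch of the usual convention; the exact expression in the evaluation code may differ slightly):
+<pre>
+PPL = exp( -(1/T_o) * sum_{t=1..T} log P(x_t | x_&#60;t) )
+</pre>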
+
+We use the following command to run WikiText-103 evaluation on a 345M parameter model.
+<pre>
+TASK="WIKITEXT103"
+
+VALID_DATA=&#60;wikitext path&#62;.txt
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+
+COMMON_TASK_ARGS="--num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --fp16 \
+ --vocab-file $VOCAB_FILE"
+
+python tasks/main.py \
+ --task $TASK \
+ $COMMON_TASK_ARGS \
+ --valid-data $VALID_DATA \
+ --tokenizer-type GPT2BPETokenizer \
+ --merge-file $MERGE_FILE \
+ --load $CHECKPOINT_PATH \
+ --batch-size 8 \
+ --checkpoint-activations \
+ --log-interval 10 \
+ --no-load-optim \
+ --no-load-rng
+</pre>
+
+
+<a id="lambada-cloze-accuracy"></a>
+### LAMBADA Cloze Accuracy
+To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens), we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl).
+
+We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Make sure that `lambada` is part of the file path.
+
+<pre>
+TASK="LAMBADA"
+
+VALID_DATA=&#60;lambada path&#62;.json
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+COMMON_TASK_ARGS=&#60;same as those in <a href="#wikitext-perplexity-evaluation">WikiText Perplexity Evaluation</a> above&#62;
+
+python tasks/main.py \
+ --task $TASK \
+ $COMMON_TASK_ARGS \
+ --valid-data $VALID_DATA \
+ --tokenizer-type GPT2BPETokenizer \
+ --strict-lambada \
+ --merge-file $MERGE_FILE \
+ --load $CHECKPOINT_PATH \
+ --batch-size 8 \
+ --checkpoint-activations \
+ --log-interval 10 \
+ --no-load-optim \
+ --no-load-rng
+</pre>
+
+Further command line arguments are described in the source file [`main.py`](./tasks/main.py).
+
+<a id="bert-task-evaluation"></a>
+## BERT Task Evaluation
+<a id="race-evaluation"></a>
+### RACE Evaluation
+The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directories contain the RACE dataset as separate `.txt` files.
+
+<pre>
+TRAIN_DATA="data/RACE/train/middle"
+VALID_DATA="data/RACE/dev/middle \
+ data/RACE/dev/high"
+VOCAB_FILE=bert-vocab.txt
+PRETRAINED_CHECKPOINT=checkpoints/bert_345m
+CHECKPOINT_PATH=checkpoints/bert_345m_race
+COMMON_TASK_ARGS="--num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --fp16 \
+ --vocab-file $VOCAB_FILE"
+
+COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
+ --valid-data $VALID_DATA \
+ --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
+ --checkpoint-activations \
+ --save-interval 10000 \
+ --save $CHECKPOINT_PATH \
+ --log-interval 100 \
+ --eval-interval 1000 \
+ --eval-iters 10 \
+ --weight-decay 1.0e-1"
+
+python tasks/main.py \
+ --task RACE \
+ $COMMON_TASK_ARGS \
+ $COMMON_TASK_ARGS_EXT \
+ --tokenizer-type BertWordPieceLowerCase \
+ --epochs 3 \
+ --batch-size 4 \
+ --lr 1.0e-5 \
+ --warmup 0.06
+</pre>
+
+<a id="mnli-evaluation"></a>
+### MNLI Evaluation
+The following script finetunes the BERT model for evaluation with the [MultiNLI sentence pair corpus](https://www.nyu.edu/projects/bowman/multinli/). Because the matching tasks are quite similar, the script can be quickly tweaked to work with the [Quora Question Pairs](https://www.kaggle.com/quora/question-pairs-dataset) (QQP) dataset as well.
+
+<pre>
+
+TRAIN_DATA="data/glue_data/MNLI/train.tsv"
+VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
+ data/glue_data/MNLI/dev_mismatched.tsv"
+PRETRAINED_CHECKPOINT=checkpoints/bert_345m
+VOCAB_FILE=bert-vocab.txt
+CHECKPOINT_PATH=checkpoints/bert_345m_mnli
+COMMON_TASK_ARGS=&#60;same as those in <a href="#race-evaluation">RACE Evaluation</a> above&#62;
+COMMON_TASK_ARGS_EXT=&#60;same as those in <a href="#race-evaluation">RACE Evaluation</a> above&#62;
+
+python tasks/main.py \
+ --task MNLI \
+ $COMMON_TASK_ARGS \
+ $COMMON_TASK_ARGS_EXT \
+ --tokenizer-type BertWordPieceLowerCase \
+ --epochs 5 \
+ --batch-size 8 \
+ --lr 5.0e-5 \
+ --warmup 0.065
+</pre>
+
+<a id="datasets"></a>
+# Datasets
+We do not host any datasets for GPT-2 or BERT training; however, we detail their collection so that our results may be reproduced.
+
+<a id="collecting-wikipedia-training-data"></a>
+## Collecting Wikipedia Training Data
+We recommend following the Wikipedia data extraction process specified by Google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text."
+
+We recommend using the `--json` argument when using WikiExtractor, which will dump the Wikipedia data into loose json format (one json object per line), making it more manageable on the file system and readily consumable by our codebase. We recommend further preprocessing this json dataset with nltk punctuation standardization. For BERT training, add newlines between sentences during data preprocessing. This is done with the `--split-sentences` flag in `preprocess_data.py` as described [above](#data-preprocessing). (Note that if you'd like to use Wikipedia data for GPT-2 training, you should still clean it with nltk/spacy/ftfy, but do not split it into newline-separated sentences.)
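+
+Putting these steps together, a BERT-oriented Wikipedia preprocessing run might look like the following sketch (the dump, output paths, and prefixes are placeholders; the WikiExtractor invocation follows that tool's own documentation and may differ between versions):
+<pre>
+# extract plain text from the dump into loose json (one json per line)
+python WikiExtractor.py --json enwiki-latest-pages-articles.xml.bz2 -o extracted
+
+# concatenate the extracted shards into a single file, e.g. extracted/wiki.json,
+# and apply any cleanup (nltk punctuation standardization, etc.) here
+
+# BERT-style preprocessing: sentences become the basic unit
+python tools/preprocess_data.py \
+       --input extracted/wiki.json \
+       --json-keys text \
+       --split-sentences \
+       --tokenizer-type BertWordPieceLowerCase \
+       --vocab-file bert-vocab.txt \
+       --output-prefix my-bert \
+       --workers 5
+</pre>
+
+The resulting `my-bert_text_sentence` prefix is exactly what the BERT pretraining scripts above expect for `--data-path`.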
+
+<a id="collecting-gpt-2-webtext-data"></a>
+## Collecting GPT-2 Webtext Data
+We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download URLs. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For Reddit URLs corresponding to content up to October 2018, we arrived at approximately 37GB of content.
diff --git a/megatron_lm/changes.md b/megatron_lm/changes.md
new file mode 100644
index 0000000..5ffcb87
--- /dev/null
+++ b/megatron_lm/changes.md
@@ -0,0 +1 @@
+PRETEND THESE ARE CODE CHANGES
diff --git a/megatron_lm/curriculum_learning/README.md b/megatron_lm/curriculum_learning/README.md
new file mode 100644
index 0000000..a80e351
--- /dev/null
+++ b/megatron_lm/curriculum_learning/README.md
@@ -0,0 +1 @@
+This is an example of how to use DeepSpeed's curriculum learning (CL) feature which provides faster and more stable language model pre-training. Currently it is only integrated for GPT pre-training. Note that there are two curriculum learning examples in two different repos for Megatron-LM GPT-2 pre-training. Both of them have some unique features and limitations. See details in our [tutorial](https://www.deepspeed.ai/tutorials/curriculum-learning/). For technical details please refer to our [paper](https://arxiv.org/abs/2108.06084). \ No newline at end of file
diff --git a/megatron_lm/curriculum_learning/ds_pretrain_gpt2.sh b/megatron_lm/curriculum_learning/ds_pretrain_gpt2.sh
new file mode 100644
index 0000000..338b93f
--- /dev/null
+++ b/megatron_lm/curriculum_learning/ds_pretrain_gpt2.sh
@@ -0,0 +1,183 @@
+#! /bin/bash
+
+CONFIG=$1
+TAG=$2
+MODEL_SIZE=$3
+LR=$4
+TOTAL_BATCHSIZE=$5
+SEQ_LEN=$6
+MP_SIZE=$7
+SEED=$8
+SAVE_INTERVAL=$9
+NUM_ITER=${10}
+NUM_TOKEN=${11}
+LR_DECAY_TOKEN=${12}
+LR_WARMUP_ITER=${13}
+CONFIG_TEMPLATE=${14}
+CURRICULUM_STEP=${15}
+CURRICULUM_MIN=${16}
+
+# 12-layer, 768-hidden, 12-heads, 117M parameters
+# 24-layer, 1024-hidden, 16-heads, 345M parameters
+# 36-layer, 1280-hidden, 20-heads, 774M parameters
+# 48-layer, 1600-hidden, 25-heads, 1558M parameters
+if [[ $MODEL_SIZE -eq 117 ]]; then
+ NUM_LAYERS=12
+ HIDDEN_SIZE=768
+ NUM_ATTN_HEADS=12
+elif [[ $MODEL_SIZE -eq 345 ]]; then
+ NUM_LAYERS=24
+ HIDDEN_SIZE=1024
+ NUM_ATTN_HEADS=16
+elif [[ $MODEL_SIZE -eq 774 ]]; then
+ NUM_LAYERS=36
+ HIDDEN_SIZE=1280
+ NUM_ATTN_HEADS=20
+elif [[ $MODEL_SIZE -eq 1558 ]]; then
+ NUM_LAYERS=48
+ HIDDEN_SIZE=1600
+ NUM_ATTN_HEADS=25
+else
+ echo "Model size not supported."
+ exit 1
+fi
+
+# Change for multinode config
+NUM_WORKERS=16
+NUM_GPUS_PER_WORKER=8
+BATCHSIZE=$((MP_SIZE*TOTAL_BATCHSIZE/NUM_WORKERS/NUM_GPUS_PER_WORKER)) # per gpu batch size
+
+DATA_PATH=/vc_data/Megatron-LM/data/indexed_datasets/megatron
+VOCAB_PATH=/vc_data/Megatron-LM/data/gpt2-vocab.json
+MERGE_PATH=/vc_data/Megatron-LM/data/gpt2-merges.txt
+
+#ZeRO Configs
+stage=2
+reduce_scatter=true
+contigious_gradients=true
+rbs=50000000
+agbs=5000000000
+
+current_time=$(date "+%Y.%m.%d-%H.%M.%S")
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+host="${HOSTNAME}"
+
+if [ "${CONFIG_TEMPLATE}" = "true" ]; then
+template_json="$script_dir/ds_zero_stage_${stage}_config_${CONFIG}.json"
+config_json="$script_dir/ds_zero_stage_${stage}_config_${CONFIG}_min${CURRICULUM_MIN}_max${SEQ_LEN}_step${CURRICULUM_STEP}.json"
+sed "s/CONFIG_CL_MIN/${CURRICULUM_MIN}/" ${template_json} \
+ | sed "s/CONFIG_CL_MAX/${SEQ_LEN}/" \
+ | sed "s/CONFIG_CL_DURATION/${CURRICULUM_STEP}/" \
+ > ${config_json}
+else
+config_json="$script_dir/ds_zero_stage_${stage}_config_${CONFIG}.json"
+fi
+
+JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_TOKEN}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}"
+LOG_NAME="${JOB_NAME}_${host}_${current_time}"
+
+#Activation Checkpointing and Contiguous Memory
+chkp_layers=1
+PA=true
+PA_CPU=false
+CC=true
+SYNCHRONIZE=true
+PROFILE=false
+
+OUTPUT_BASEPATH="/vc_data_blob/users/conglli"
+mkdir -p "${OUTPUT_BASEPATH}/tensorboard/curriculum/"
+mkdir -p "${OUTPUT_BASEPATH}/checkpoint/curriculum/"
+mkdir -p "${OUTPUT_BASEPATH}/log/curriculum/"
+LOGDIR="${OUTPUT_BASEPATH}/tensorboard/curriculum/${LOG_NAME}"
+CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/curriculum/${JOB_NAME}"
+
+gpt_options=" \
+ --model-parallel-size ${MP_SIZE} \
+ --num-layers $NUM_LAYERS \
+ --hidden-size $HIDDEN_SIZE \
+ --num-attention-heads $NUM_ATTN_HEADS \
+ --seq-length $SEQ_LEN \
+ --max-position-embeddings $SEQ_LEN \
+ --batch-size $BATCHSIZE \
+ --train-iters $NUM_ITER \
+ --train-tokens $NUM_TOKEN \
+ --lr-decay-tokens $LR_DECAY_TOKEN \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file $VOCAB_PATH \
+ --merge-file $MERGE_PATH \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr $LR \
+ --lr-decay-style cosine \
+ --min-lr 1.0e-5 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup-iters $LR_WARMUP_ITER \
+ --checkpoint-activations \
+ --log-interval 100 \
+ --save-interval $SAVE_INTERVAL \
+ --eval-interval 100 \
+ --eval-iters 10 \
+ --fp16 \
+ --seed $SEED \
+ --tensorboard-dir ${LOGDIR}
+"
+
+deepspeed_options=" \
+ --deepspeed \
+ --deepspeed_config ${config_json} \
+ --zero-stage ${stage} \
+ --zero-reduce-bucket-size ${rbs} \
+ --zero-allgather-bucket-size ${agbs}
+"
+
+if [ "${contigious_gradients}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-contigious-gradients"
+fi
+
+if [ "${reduce_scatter}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-reduce-scatter"
+fi
+
+chkp_opt=" \
+--deepspeed-activation-checkpointing \
+--checkpoint-num-layers ${chkp_layers}"
+
+if [ "${PA}" = "true" ]; then
+chkp_opt="${chkp_opt} --partition-activations"
+fi
+
+if [ "${PA_CPU}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --checkpoint-in-cpu"
+fi
+
+if [ "${SYNCHRONIZE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --synchronize-each-layer"
+fi
+
+if [ "${CC}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --contigious-checkpointing"
+fi
+
+if [ "${PROFILE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --profile-backward"
+fi
+
+
+full_options="${gpt_options} ${deepspeed_options} ${chkp_opt}"
+
+run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} ../pretrain_gpt2.py ${@:17} ${full_options} &>> ${OUTPUT_BASEPATH}/log/curriculum/${JOB_NAME}.log"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/megatron_lm/curriculum_learning/ds_train.sh b/megatron_lm/curriculum_learning/ds_train.sh
new file mode 100644
index 0000000..aac11ab
--- /dev/null
+++ b/megatron_lm/curriculum_learning/ds_train.sh
@@ -0,0 +1,37 @@
+# # baseline
+# CONFIG=baseline
+# TAG=baseline
+# MODEL_SIZE=1558
+# LR=1.5e-4
+# BSZ=512
+# SEQ_LEN=1024
+# MP_SIZE=1
+# SEED=1234
+# SAVE_INTERVAL=5000
+# NUM_ITER=600000
+# NUM_TOKEN=157286400000
+# LR_DECAY_TOKEN=157286400000
+# LR_WARMUP_ITER=3000
+# CONFIG_TEMPLATE=false
+# CURRICULUM_STEP=0
+# CURRICULUM_MIN=0
+
+# curriculum learning
+CONFIG=curriculum_fixed_linear
+MODEL_SIZE=1558
+LR=6e-4
+BSZ=4096
+SEQ_LEN=1024
+MP_SIZE=1
+SEED=1234
+SAVE_INTERVAL=1000
+NUM_ITER=75000
+NUM_TOKEN=157286400000
+LR_DECAY_TOKEN=157286400000
+LR_WARMUP_ITER=3000
+CONFIG_TEMPLATE=true
+CURRICULUM_STEP=45000
+CURRICULUM_MIN=64
+TAG="${CONFIG}_s${CURRICULUM_MIN}to${SEQ_LEN}_step${CURRICULUM_STEP}"
+
+bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_TOKEN $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN
diff --git a/megatron_lm/curriculum_learning/ds_zero_stage_2_config_baseline.json b/megatron_lm/curriculum_learning/ds_zero_stage_2_config_baseline.json
new file mode 100644
index 0000000..e2f1866
--- /dev/null
+++ b/megatron_lm/curriculum_learning/ds_zero_stage_2_config_baseline.json
@@ -0,0 +1,31 @@
+{
+ "train_batch_size": 512,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "reduce_scatter": true,
+ "allgather_bucket_size": 50000000,
+ "reduce_bucket_size": 50000000,
+ "overlap_comm": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.00015,
+ "max_grad_norm": 1.0,
+ "betas": [0.9, 0.95]
+ }
+ },
+ "gradient_clipping": 1.0,
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "wall_clock_breakdown": false,
+ "zero_allow_untested_optimizer": false
+}
diff --git a/megatron_lm/curriculum_learning/ds_zero_stage_2_config_curriculum_fixed_linear.json b/megatron_lm/curriculum_learning/ds_zero_stage_2_config_curriculum_fixed_linear.json
new file mode 100644
index 0000000..e46144c
--- /dev/null
+++ b/megatron_lm/curriculum_learning/ds_zero_stage_2_config_curriculum_fixed_linear.json
@@ -0,0 +1,42 @@
+{
+ "train_batch_size": 512,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "reduce_scatter": true,
+ "allgather_bucket_size": 50000000,
+ "reduce_bucket_size": 50000000,
+ "overlap_comm": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.00015,
+ "max_grad_norm": 1.0,
+ "betas": [0.9, 0.95]
+ }
+ },
+ "gradient_clipping": 1.0,
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "wall_clock_breakdown": false,
+ "zero_allow_untested_optimizer": false,
+ "curriculum_learning": {
+ "enabled": true,
+ "curriculum_type": "seqlen",
+ "min_difficulty": CONFIG_CL_MIN,
+ "max_difficulty": CONFIG_CL_MAX,
+ "schedule_type": "fixed_linear",
+ "schedule_config": {
+ "total_curriculum_step": CONFIG_CL_DURATION,
+ "difficulty_step": 8
+ }
+ }
+}
diff --git a/megatron_lm/examples/ds_pretrain_gpt2-zero2.sh b/megatron_lm/examples/ds_pretrain_gpt2-zero2.sh
new file mode 100755
index 0000000..36a0403
--- /dev/null
+++ b/megatron_lm/examples/ds_pretrain_gpt2-zero2.sh
@@ -0,0 +1,164 @@
+#! /bin/bash
+
+# Change for multinode config
+MP_SIZE=1
+
+DEBUG=1
+if [[ ${DEBUG} == 1 ]]; then
+ MP_SIZE=1
+ NUM_WORKERS=1
+ NUM_GPUS_PER_WORKER=1
+ HIDDEN_SIZE=1024
+ NUM_ATTN_HEADS=16
+ NUM_LAYERS=24
+ BATCHSIZE=4
+else
+ NUM_WORKERS=${DLTS_NUM_WORKER}
+ NUM_GPUS_PER_WORKER=${DLTS_NUM_GPU_PER_WORKER}
+ HIDDEN_SIZE=8192
+ NUM_ATTN_HEADS=32
+ NUM_LAYERS=50
+ BATCHSIZE=4
+
+ #HIDDEN_SIZE=4096
+ #NUM_LAYERS=24 # 50
+ #BATCHSIZE=16
+fi
+
+
+BASE_DATA_PATH=/data/Megatron-LM/data
+DATA_PATH=${BASE_DATA_PATH}/indexed_datasets/megatron
+VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
+MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m_ds
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+if [[ -z $1 ]]; then
+ config_json="$script_dir/ds_zero_stage_2_config.json"
+
+ # offloads to NVMe
+ #config_json="$script_dir/ds_zero_stage_infinity_config.json"
+else
+ config_json=$script_dir/`basename $1`
+fi
+
+#ZeRO Configs
+stage=2
+reduce_scatter=true
+contigious_gradients=true
+rbs=50000000
+agbs=5000000000
+
+#Activation Checkpointing and Contiguous Memory
+chkp_layers=1
+PA=true
+PA_CPU=true
+CC=true
+SYNCHRONIZE=true
+PROFILE=false
+
+# TiledLinear splits; 0 disables tiling
+TILED_LINEAR="false"
+TILE_DIM=1
+
+
+# Megatron Model Parallelism
+LOGDIR="tboard-zero2/stage${stage}-lazyscatter-${NUM_LAYERS}l_${HIDDEN_SIZE}h_${NUM_WORKERS}n_${NUM_GPUS_PER_WORKER}g_${MP_SIZE}mp_${BATCHSIZE}b"
+
+
+gpt_options=" \
+ --model-parallel-size ${MP_SIZE} \
+ --num-layers $NUM_LAYERS \
+ --hidden-size $HIDDEN_SIZE \
+ --num-attention-heads ${NUM_ATTN_HEADS} \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --batch-size $BATCHSIZE \
+ --train-iters 320000 \
+ --lr-decay-iters 320000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file $VOCAB_PATH \
+ --merge-file $MERGE_PATH \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 1.5e-4 \
+ --lr-decay-style cosine \
+ --min-lr 1.0e-5 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup 0.01 \
+ --checkpoint-activations \
+ --log-interval 1 \
+ --save-interval 10000 \
+ --eval-interval 2000 \
+ --eval-iters 10 \
+ --fp16 \
+ --scattered-embeddings \
+ --split-transformers \
+"
+ #--tensorboard-dir ${LOGDIR}
+
+ deepspeed_options=" \
+ --deepspeed \
+ --deepspeed_config ${config_json} \
+ --zero-stage ${stage} \
+ --zero-reduce-bucket-size ${rbs} \
+ --zero-allgather-bucket-size ${agbs}
+ "
+
+if [ "${contigious_gradients}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-contigious-gradients"
+fi
+
+if [ "${reduce_scatter}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-reduce-scatter"
+fi
+
+chkp_opt=" \
+--deepspeed-activation-checkpointing \
+--checkpoint-num-layers ${chkp_layers}"
+
+if [ "${PA}" = "true" ]; then
+chkp_opt="${chkp_opt} --partition-activations"
+fi
+
+if [ "${PA_CPU}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --checkpoint-in-cpu"
+fi
+
+if [ "${SYNCHRONIZE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --synchronize-each-layer"
+fi
+
+if [ "${CC}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --contigious-checkpointing"
+fi
+
+if [ "${PROFILE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --profile-backward"
+fi
+
+if [ "${TILED_LINEAR}" = "true" ]; then
+tile_opt="${tile_opt} \
+ --memory-centric-tiled-linear \
+ --tile-factor=${TILE_DIM}"
+fi
+
+
+full_options="${gpt_options} ${deepspeed_options} ${chkp_opt} ${tile_opt}"
+
+run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py ${@:2} ${full_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/megatron_lm/examples/ds_pretrain_gpt2-zero3.sh b/megatron_lm/examples/ds_pretrain_gpt2-zero3.sh
new file mode 100755
index 0000000..1eaaa9b
--- /dev/null
+++ b/megatron_lm/examples/ds_pretrain_gpt2-zero3.sh
@@ -0,0 +1,164 @@
+#! /bin/bash
+
+# Change for multinode config
+MP_SIZE=1
+
+DEBUG=1
+if [[ ${DEBUG} == 1 ]]; then
+ MP_SIZE=1
+ NUM_WORKERS=1
+ NUM_GPUS_PER_WORKER=1
+ HIDDEN_SIZE=1024
+ NUM_ATTN_HEADS=16
+ NUM_LAYERS=5
+ BATCHSIZE=4
+else
+ NUM_WORKERS=${DLTS_NUM_WORKER}
+ NUM_GPUS_PER_WORKER=${DLTS_NUM_GPU_PER_WORKER}
+ HIDDEN_SIZE=8192
+ NUM_ATTN_HEADS=32
+ NUM_LAYERS=50
+ BATCHSIZE=4
+
+ #HIDDEN_SIZE=4096
+ #NUM_LAYERS=24 # 50
+ #BATCHSIZE=16
+fi
+
+
+BASE_DATA_PATH=/data/Megatron-LM/data
+DATA_PATH=${BASE_DATA_PATH}/indexed_datasets/megatron
+VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
+MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m_ds
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+if [[ -z $1 ]]; then
+ config_json="$script_dir/ds_zero_stage_3_config.json"
+
+ # offloads to NVMe
+ #config_json="$script_dir/ds_zero_stage_infinity_config.json"
+else
+ config_json=$script_dir/`basename $1`
+fi
+
+#ZeRO Configs
+stage=3
+reduce_scatter=true
+contigious_gradients=true
+rbs=50000000
+agbs=5000000000
+
+#Activation Checkpointing and Contiguous Memory
+chkp_layers=1
+PA=true
+PA_CPU=true
+CC=true
+SYNCHRONIZE=true
+PROFILE=false
+
+# TiledLinear splits; 0 disables tiling
+TILED_LINEAR="false"
+TILE_DIM=1
+
+
+# Megatron Model Parallelism
+LOGDIR="tboard-zero3/stage${stage}-lazyscatter-${NUM_LAYERS}l_${HIDDEN_SIZE}h_${NUM_WORKERS}n_${NUM_GPUS_PER_WORKER}g_${MP_SIZE}mp_${BATCHSIZE}b"
+
+
+gpt_options=" \
+ --model-parallel-size ${MP_SIZE} \
+ --num-layers $NUM_LAYERS \
+ --hidden-size $HIDDEN_SIZE \
+ --num-attention-heads ${NUM_ATTN_HEADS} \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --batch-size $BATCHSIZE \
+ --train-iters 320000 \
+ --lr-decay-iters 320000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file $VOCAB_PATH \
+ --merge-file $MERGE_PATH \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 1.5e-4 \
+ --lr-decay-style cosine \
+ --min-lr 1.0e-5 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup 0.01 \
+ --checkpoint-activations \
+ --log-interval 1 \
+ --save-interval 10000 \
+ --eval-interval 2000 \
+ --eval-iters 10 \
+ --fp16 \
+ --scattered-embeddings \
+ --split-transformers \
+"
+ #--tensorboard-dir ${LOGDIR}
+
+ deepspeed_options=" \
+ --deepspeed \
+ --deepspeed_config ${config_json} \
+ --zero-stage ${stage} \
+ --zero-reduce-bucket-size ${rbs} \
+ --zero-allgather-bucket-size ${agbs}
+ "
+
+if [ "${contigious_gradients}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-contigious-gradients"
+fi
+
+if [ "${reduce_scatter}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-reduce-scatter"
+fi
+
+chkp_opt=" \
+--deepspeed-activation-checkpointing \
+--checkpoint-num-layers ${chkp_layers}"
+
+if [ "${PA}" = "true" ]; then
+chkp_opt="${chkp_opt} --partition-activations"
+fi
+
+if [ "${PA_CPU}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --checkpoint-in-cpu"
+fi
+
+if [ "${SYNCHRONIZE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --synchronize-each-layer"
+fi
+
+if [ "${CC}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --contigious-checkpointing"
+fi
+
+if [ "${PROFILE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --profile-backward"
+fi
+
+if [ "${TILED_LINEAR}" = "true" ]; then
+tile_opt="${tile_opt} \
+ --memory-centric-tiled-linear \
+ --tile-factor=${TILE_DIM}"
+fi
+
+
+full_options="${gpt_options} ${deepspeed_options} ${chkp_opt} ${tile_opt}"
+
+run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py ${@:2} ${full_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/megatron_lm/examples/ds_pretrain_gpt2.sh b/megatron_lm/examples/ds_pretrain_gpt2.sh
new file mode 100755
index 0000000..3baad70
--- /dev/null
+++ b/megatron_lm/examples/ds_pretrain_gpt2.sh
@@ -0,0 +1,133 @@
+#! /bin/bash
+
+GPUS_PER_NODE=8
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+export DLWS_NUM_WORKER=${NNODES}
+export DLWS_NUM_GPU_PER_WORKER=${GPUS_PER_NODE}
+
+DATA_PATH=/data/megatron-data/indexed/my-gpt2_text_document
+VOCAB_PATH=/data/megatron-data/gpt2-vocab.json
+MERGE_PATH=/data/megatron-data/gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m_ds
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+config_json="$script_dir/ds_zero_stage_2_config.json"
+
+# Megatron Model Parallelism
+mp_size=4
+
+NLAYERS=24
+NHIDDEN=1024
+BATCHSIZE=9
+LOGDIR="tensorboard_data/${NLAYERS}l_${NHIDDEN}h_${NNODES}n_${GPUS_PER_NODE}g_${mp_size}mp_${BATCHSIZE}b_ds4"
+
+#ZeRO Configs
+stage=0
+reduce_scatter=true
+contigious_gradients=true
+rbs=50000000
+agbs=5000000000
+
+#Activation Checkpointing and Contiguous Memory
+chkp_layers=1
+PA=true
+PA_CPU=false
+CC=true
+SYNCHRONIZE=true
+PROFILE=false
+
+
+gpt_options=" \
+ --model-parallel-size ${mp_size} \
+ --num-layers $NLAYERS \
+ --hidden-size $NHIDDEN \
+ --num-attention-heads 16 \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --batch-size $BATCHSIZE \
+ --train-iters 320000 \
+ --lr-decay-iters 320000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file $VOCAB_PATH \
+ --merge-file $MERGE_PATH \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 1.5e-4 \
+ --lr-decay-style cosine \
+ --min-lr 1.0e-5 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup 0.01 \
+ --checkpoint-activations \
+ --log-interval 100 \
+ --save-interval 10000 \
+ --eval-interval 1000 \
+ --eval-iters 10 \
+ --fp16 \
+ --tensorboard-dir ${LOGDIR}
+"
+
+ deepspeed_options=" \
+ --deepspeed \
+ --deepspeed_config ${config_json} \
+ --zero-stage ${stage} \
+ --zero-reduce-bucket-size ${rbs} \
+ --zero-allgather-bucket-size ${agbs}
+ "
+
+if [ "${contigious_gradients}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-contigious-gradients"
+fi
+
+if [ "${reduce_scatter}" = "true" ]; then
+deepspeed_options="${deepspeed_options} \
+ --zero-reduce-scatter"
+fi
+
+chkp_opt=" \
+--checkpoint-activations \
+--checkpoint-num-layers ${chkp_layers}"
+
+if [ "${PA}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --partition-activations"
+fi
+
+if [ "${PA_CPU}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --checkpoint-in-cpu"
+fi
+
+if [ "${SYNCHRONIZE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --synchronize-each-layer"
+fi
+
+if [ "${CC}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --contigious-checkpointing"
+fi
+
+if [ "${PROFILE}" = "true" ]; then
+chkp_opt="${chkp_opt} \
+ --profile-backward"
+fi
+
+full_options="${gpt_options} ${deepspeed_options} ${chkp_opt}"
+
+run_cmd="deepspeed --num_nodes ${DLWS_NUM_WORKER} --num_gpus ${DLWS_NUM_GPU_PER_WORKER} pretrain_gpt2.py $@ ${full_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/megatron_lm/examples/ds_zero_stage_2_config.json b/megatron_lm/examples/ds_zero_stage_2_config.json
new file mode 100755
index 0000000..2ab86c2
--- /dev/null
+++ b/megatron_lm/examples/ds_zero_stage_2_config.json
@@ -0,0 +1,32 @@
+{
+ "train_batch_size": 2048,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "reduce_scatter": true,
+ "allgather_bucket_size": 50000000,
+ "reduce_bucket_size": 50000000,
+ "overlap_comm": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.00015,
+ "max_grad_norm": 1.0,
+ "betas": [0.9, 0.95]
+ }
+ },
+ "gradient_clipping": 1.0,
+ "fp16": {
+ "enabled": true,
+
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "wall_clock_breakdown": true,
+ "zero_allow_untested_optimizer": false
+}
diff --git a/megatron_lm/examples/ds_zero_stage_3_config.json b/megatron_lm/examples/ds_zero_stage_3_config.json
new file mode 100755
index 0000000..d9a90d8
--- /dev/null
+++ b/megatron_lm/examples/ds_zero_stage_3_config.json
@@ -0,0 +1,24 @@
+{
+ "train_batch_size": 64,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_prefetch_bucket_size": 1e7,
+ "stage3_param_persitence_threshold": 1e5,
+ "reduce_bucket_size": 1e7,
+ "contiguous_gradients": true
+ },
+ "gradient_clipping": 1.0,
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "wall_clock_breakdown": true,
+ "zero_allow_untested_optimizer": false
+}
diff --git a/megatron_lm/examples/ds_zero_stage_3_config_release.json b/megatron_lm/examples/ds_zero_stage_3_config_release.json
new file mode 100755
index 0000000..5c94b95
--- /dev/null
+++ b/megatron_lm/examples/ds_zero_stage_3_config_release.json
@@ -0,0 +1,29 @@
+{
+ "train_batch_size": 64,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e8,
+ "stage3_param_persitance_threshold": 1e5,
+ "stage3_prefetch_bucket_size": 5e7,
+ "contiguous_gradients": true,
+ "cpu_offload": true,
+ "cpu_offload_params": true,
+ "cpu_offload_use_pin_memory": true,
+ "overlap_comm": true,
+ "reduce_bucket_size": 90000000,
+ "sub_group_size": 4e8
+ },
+ "gradient_clipping": 1.0,
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "wall_clock_breakdown": true,
+ "zero_allow_untested_optimizer": false
+} \ No newline at end of file
diff --git a/megatron_lm/examples/ds_zero_stage_infinity_config.json b/megatron_lm/examples/ds_zero_stage_infinity_config.json
new file mode 100755
index 0000000..460476a
--- /dev/null
+++ b/megatron_lm/examples/ds_zero_stage_infinity_config.json
@@ -0,0 +1,47 @@
+{
+ "train_micro_batch_size_per_gpu": 4,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_param_persitence_threshold": 1e5,
+ "stage3_prefetch_bucket_size": 5e7,
+ "contiguous_gradients": true,
+ "overlap_comm": true,
+ "reduce_bucket_size": 90000000,
+ "sub_group_size": 1e8,
+ "offload_optimizer": {
+ "device": "nvme",
+ "nvme_path": "/nvme_data",
+ "buffer_count": 4,
+ "pipeline_read": false,
+ "pipeline_write": false,
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "nvme",
+ "nvme_path": "/nvme_data",
+ "max_in_cpu": 1,
+ "pin_memory": true
+ }
+ },
+ "gradient_clipping": 1.0,
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "wall_clock_breakdown": true,
+ "zero_allow_untested_optimizer": false,
+ "aio": {
+ "block_size": 1048576,
+ "queue_depth": 16,
+ "single_submit": false,
+ "overlap_events": true,
+ "thread_count": 2
+ }
+}
diff --git a/megatron_lm/examples/evaluate_zeroshot_gpt2.sh b/megatron_lm/examples/evaluate_zeroshot_gpt2.sh
new file mode 100755
index 0000000..f4f9f22
--- /dev/null
+++ b/megatron_lm/examples/evaluate_zeroshot_gpt2.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+WORLD_SIZE=8
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port 6000"
+
+TASK="LAMBADA"
+
+VALID_DATA=<lambada path>
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+CHECKPOINT=checkpoints/gpt2_345m
+
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
+ --task $TASK \
+ --valid-data $VALID_DATA \
+ --tokenizer-type GPT2BPETokenizer \
+ --strict-lambada \
+ --vocab-file $VOCAB_FILE \
+ --merge-file $MERGE_FILE \
+ --load $CHECKPOINT \
+ --model-parallel-size 1 \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 8 \
+ --checkpoint-activations \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --log-interval 10 \
+ --fp16 \
+ --no-load-optim \
+ --no-load-rng
diff --git a/megatron_lm/examples/finetune_mnli_distributed.sh b/megatron_lm/examples/finetune_mnli_distributed.sh
new file mode 100755
index 0000000..65f3a9f
--- /dev/null
+++ b/megatron_lm/examples/finetune_mnli_distributed.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+WORLD_SIZE=8
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port 6000"
+
+TRAIN_DATA="data/glue_data/MNLI/train.tsv"
+VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
+ data/glue_data/MNLI/dev_mismatched.tsv"
+PRETRAINED_CHECKPOINT=checkpoints/bert_345m
+VOCAB_FILE=bert-vocab.txt
+CHECKPOINT_PATH=checkpoints/bert_345m_mnli
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
+ --task MNLI \
+ --seed 1234 \
+ --train-data $TRAIN_DATA \
+ --valid-data $VALID_DATA \
+ --tokenizer-type BertWordPieceLowerCase \
+ --vocab-file $VOCAB_FILE \
+ --epochs 5 \
+ --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
+ --model-parallel-size 1 \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 8 \
+ --checkpoint-activations \
+ --lr 5.0e-5 \
+ --lr-decay-style linear \
+ --warmup 0.065 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --save-interval 500000 \
+ --save $CHECKPOINT_PATH \
+ --log-interval 10 \
+ --eval-interval 100 \
+ --eval-iters 50 \
+ --weight-decay 1.0e-1 \
+ --fp16
diff --git a/megatron_lm/examples/finetune_race_distributed.sh b/megatron_lm/examples/finetune_race_distributed.sh
new file mode 100755
index 0000000..0212ecb
--- /dev/null
+++ b/megatron_lm/examples/finetune_race_distributed.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+WORLD_SIZE=8
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port 6000"
+
+TRAIN_DATA="data/RACE/train/middle"
+VALID_DATA="data/RACE/dev/middle \
+ data/RACE/dev/high"
+VOCAB_FILE=bert-vocab.txt
+PRETRAINED_CHECKPOINT=checkpoints/bert_345m
+CHECKPOINT_PATH=checkpoints/bert_345m_race
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
+ --task RACE \
+ --seed 1234 \
+ --train-data $TRAIN_DATA \
+ --valid-data $VALID_DATA \
+ --tokenizer-type BertWordPieceLowerCase \
+ --vocab-file $VOCAB_FILE \
+ --epochs 3 \
+ --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
+ --model-parallel-size 1 \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 4 \
+ --checkpoint-activations \
+ --lr 1.0e-5 \
+ --lr-decay-style linear \
+ --warmup 0.06 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --save-interval 100000 \
+ --save $CHECKPOINT_PATH \
+ --log-interval 10 \
+ --eval-interval 100 \
+ --eval-iters 50 \
+ --weight-decay 1.0e-1 \
+ --clip-grad 1.0 \
+ --hidden-dropout 0.1 \
+ --attention-dropout 0.1 \
+ --fp16
diff --git a/megatron_lm/examples/generate_text.sh b/megatron_lm/examples/generate_text.sh
new file mode 100755
index 0000000..6a04c49
--- /dev/null
+++ b/megatron_lm/examples/generate_text.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+
+python tools/generate_samples_gpt2.py \
+ --model-parallel-size 1 \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --load $CHECKPOINT_PATH \
+ --num-attention-heads 16 \
+ --max-position-embeddings 1024 \
+ --tokenizer-type GPT2BPETokenizer \
+ --fp16 \
+ --batch-size 2 \
+ --seq-length 1024 \
+ --out-seq-length 1024 \
+ --temperature 1.0 \
+ --vocab-file $VOCAB_FILE \
+ --merge-file $MERGE_FILE \
+ --genfile unconditional_samples.json \
+ --num-samples 2 \
+ --top_p 0.9 \
+ --recompute
diff --git a/megatron_lm/examples/merge_mp_bert.sh b/megatron_lm/examples/merge_mp_bert.sh
new file mode 100755
index 0000000..01e08b1
--- /dev/null
+++ b/megatron_lm/examples/merge_mp_bert.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+MODEL_PARALLEL_SIZE=2
+
+VOCAB_FILE=bert-vocab.txt
+CHECKPOINT_PATH=checkpoints/bert_345m
+
+WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
+ --model-type BERT \
+ --model-parallel-size $MODEL_PARALLEL_SIZE \
+ --tokenizer-type BertWordPieceLowerCase \
+ --vocab-file $VOCAB_FILE \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --load $CHECKPOINT_PATH
diff --git a/megatron_lm/examples/pretrain_bert.sh b/megatron_lm/examples/pretrain_bert.sh
new file mode 100755
index 0000000..ecf5947
--- /dev/null
+++ b/megatron_lm/examples/pretrain_bert.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+RANK=0
+WORLD_SIZE=1
+DATA_PATH=<Specify path and file prefix>_text_sentence
+CHECKPOINT_PATH=<Specify path>
+
+python pretrain_bert.py \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 4 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --train-iters 2000000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file bert-vocab.txt \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 0.0001 \
+ --min-lr 0.00001 \
+ --lr-decay-style linear \
+ --lr-decay-iters 990000 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup .01 \
+ --log-interval 100 \
+ --save-interval 10000 \
+ --eval-interval 1000 \
+ --eval-iters 10 \
+ --fp16
+
diff --git a/megatron_lm/examples/pretrain_bert_distributed.sh b/megatron_lm/examples/pretrain_bert_distributed.sh
new file mode 100755
index 0000000..17ebae1
--- /dev/null
+++ b/megatron_lm/examples/pretrain_bert_distributed.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+GPUS_PER_NODE=8
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+DATA_PATH=<Specify path and file prefix>_text_sentence
+CHECKPOINT_PATH=<Specify path>
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS \
+ pretrain_bert.py \
+ --model-parallel-size 1 \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 4 \
+ --seq-length 512 \
+ --max-position-embeddings 512 \
+ --train-iters 1000000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file bert-vocab.txt \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 0.0001 \
+ --lr-decay-style linear \
+ --min-lr 1.0e-5 \
+ --lr-decay-iters 990000 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup .01 \
+ --log-interval 100 \
+ --save-interval 10000 \
+ --eval-interval 1000 \
+ --eval-iters 10 \
+ --fp16
diff --git a/megatron_lm/examples/pretrain_gpt2.sh b/megatron_lm/examples/pretrain_gpt2.sh
new file mode 100755
index 0000000..66232bf
--- /dev/null
+++ b/megatron_lm/examples/pretrain_gpt2.sh
@@ -0,0 +1,43 @@
+#! /bin/bash
+
+# Runs the "345M" parameter model
+
+RANK=0
+WORLD_SIZE=1
+
+DATA_PATH=<Specify path and file prefix>_text_document
+CHECKPOINT_PATH=<Specify path>
+
+
+python pretrain_gpt2.py \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 8 \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --train-iters 500000 \
+ --lr-decay-iters 320000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file gpt2-vocab.json \
+ --merge-file gpt2-merges.txt \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 0.00015 \
+ --min-lr 1.0e-5 \
+ --lr-decay-style cosine \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup .01 \
+ --checkpoint-activations \
+ --log-interval 100 \
+ --save-interval 10000 \
+ --eval-interval 1000 \
+ --eval-iters 10 \
+ --fp16
+
+
+set +x
diff --git a/megatron_lm/examples/pretrain_gpt2_distributed.sh b/megatron_lm/examples/pretrain_gpt2_distributed.sh
new file mode 100755
index 0000000..1d74625
--- /dev/null
+++ b/megatron_lm/examples/pretrain_gpt2_distributed.sh
@@ -0,0 +1,52 @@
+#! /bin/bash
+
+# Runs the "345M" parameter model
+
+GPUS_PER_NODE=8
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+DATA_PATH=<Specify path and file prefix>_text_document
+CHECKPOINT_PATH=<Specify path>
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS \
+ pretrain_gpt2.py \
+ --model-parallel-size 1 \
+ --num-layers 24 \
+ --hidden-size 1024 \
+ --num-attention-heads 16 \
+ --batch-size 8 \
+ --seq-length 1024 \
+ --max-position-embeddings 1024 \
+ --train-iters 500000 \
+ --lr-decay-iters 320000 \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file gpt2-vocab.json \
+ --merge-file gpt2-merges.txt \
+ --data-impl mmap \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 0.00015 \
+ --lr-decay-style cosine \
+ --min-lr 1.0e-5 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --warmup .01 \
+ --checkpoint-activations \
+ --log-interval 100 \
+ --save-interval 10000 \
+ --eval-interval 1000 \
+ --eval-iters 10 \
+ --fp16
+
+
+
+set +x
diff --git a/megatron_lm/images/Makefile b/megatron_lm/images/Makefile
new file mode 100644
index 0000000..5efde02
--- /dev/null
+++ b/megatron_lm/images/Makefile
@@ -0,0 +1,12 @@
+default: cases.png scaling-mp.png scaling-dp.png
+
+# for some reason the size option to convert in tables.tex doesn't work; resize manually afterwards
+cases.png scaling-mp.png scaling-dp.png: tables.tex
+ latex --shell-escape $<
+ convert tables-1.png -resize 650 cases.png
+ convert tables-2.png -resize 600 scaling-mp.png
+ convert tables-3.png -resize 350 scaling-dp.png
+
+clean:
+ rm -rf *.aux *.log *.dvi *.ps
+ rm -rf tables-*.png
diff --git a/megatron_lm/images/cases.png b/megatron_lm/images/cases.png
new file mode 100644
index 0000000..8f52c38
--- /dev/null
+++ b/megatron_lm/images/cases.png
Binary files differ
diff --git a/megatron_lm/images/scaling-dp.png b/megatron_lm/images/scaling-dp.png
new file mode 100644
index 0000000..ce3ae95
--- /dev/null
+++ b/megatron_lm/images/scaling-dp.png
Binary files differ
diff --git a/megatron_lm/images/scaling-mp.png b/megatron_lm/images/scaling-mp.png
new file mode 100644
index 0000000..75f3498
--- /dev/null
+++ b/megatron_lm/images/scaling-mp.png
Binary files differ
diff --git a/megatron_lm/images/tables.tex b/megatron_lm/images/tables.tex
new file mode 100644
index 0000000..86d744e
--- /dev/null
+++ b/megatron_lm/images/tables.tex
@@ -0,0 +1,40 @@
+\documentclass[multi,convert]{standalone}
+\usepackage{multirow}
+\standaloneenv{tabular}
+
+\begin{document}
+
+\begin{tabular}{cccccc}
+ Case & Hidden Size & Attention Heads & Layers & Parameters (billions) & Model Parallel Partitions \\
+ \hline
+ 1B & 1920 & 15 & 24 & 1.16 & 1 \\
+ 2B & 2304 & 18 & 30 & 2.03 & 2 \\
+ 4B & 3072 & 24 & 36 & 4.24 & 4 \\
+ 8B & 4096 & 32 & 42 & 8.67 & 8 \\
+\end{tabular}
+
+\begin{tabular}{cc|ccc|ccc}
+ & & \multicolumn{3}{c|}{\textbf{DGX-2 (V100) batch size 8}} & \multicolumn{3}{c}{\textbf{DGX-A100 batch size 16}} \\
+ \hline
+ \multirow{2}{*}{Case} & Number of & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs \\
+ & GPUs & Time (ms) & & per GPU & Time (ms) & & per GPU \\
+ \hline
+ 1B & 1 & 1121 & 100.0\% & 71.9 & 1076 & 100\% & 149.8 \\
+ 2B & 2 & 1093 & 89.6\% & 64.2 & 1026 & 91.7\% & 136.8 \\
+ 4B & 4 & 1238 & 82.5\% & 58.5 & 1162 & 84.5\% & 124.7 \\
+ 8B & 8 & 1407 & 74.3\% & 52.2 & 1343 & 74.7\% & 109.3 \\
+\end{tabular}
+
+\begin{tabular}{cc|ccc}
+ & & \multicolumn{3}{c}{\textbf{DGX-A100 batch size 2048}} \\
+ \hline
+ \multirow{2}{*}{Case} & Number of & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs \\
+ & GPUs & Time (ms) & & per GPU \\
+ \hline
+ 1B & 128 & 1153 & 93.3\% & 139.8 \\
+ 2B & 256 & 1101 & 85.5\% & 127.5 \\
+ 4B & 512 & 1242 & 79.0\% & 116.7 \\
+ 8B & 1024 & 1380 & 72.7\% & 106.5 \\
+\end{tabular}
+
+\end{document}
diff --git a/megatron_lm/megatron/__init__.py b/megatron_lm/megatron/__init__.py
new file mode 100644
index 0000000..3f4fbda
--- /dev/null
+++ b/megatron_lm/megatron/__init__.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from .package_info import (
+ __description__,
+ __contact_names__,
+ __url__,
+ __download_url__,
+ __keywords__,
+ __license__,
+ __package_name__,
+ __version__,
+)
+
+from .global_vars import get_args
+from .global_vars import get_tokenizer
+from .global_vars import get_tensorboard_writer
+from .global_vars import get_adlr_autoresume
+from .global_vars import get_timers
+from .initialize import initialize_megatron
+
+def print_rank_0(message):
+ """If distributed is initialized print only on rank 0."""
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == 0:
+ print(message, flush=True)
+ else:
+ print(message, flush=True)
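
A minimal usage sketch of the exports above (not part of the patch). It assumes the upstream Megatron-LM keyword names for initialize_megatron and that a distributed launcher has already set RANK / WORLD_SIZE / MASTER_ADDR and passed the model arguments on the command line.

# Illustrative sketch, not part of the patch: typical use of the exports above.
from megatron import initialize_megatron, get_args, print_rank_0

def main():
    # Parses arguments, sets up global state, and initializes torch.distributed.
    initialize_megatron(extra_args_provider=None, args_defaults={})
    args = get_args()
    # print_rank_0 keeps the log from being duplicated across workers.
    print_rank_0('initialized with model-parallel size {}'.format(
        args.model_parallel_size))

if __name__ == '__main__':
    main()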
diff --git a/megatron_lm/megatron/arguments.py b/megatron_lm/megatron/arguments.py
new file mode 100644
index 0000000..c7b7f8d
--- /dev/null
+++ b/megatron_lm/megatron/arguments.py
@@ -0,0 +1,587 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron arguments."""
+
+import argparse
+import os
+
+import torch
+from megatron import fused_kernels
+
+import deepspeed
+
+def parse_args(extra_args_provider=None, defaults={},
+ ignore_unknown_args=False):
+ """Parse all arguments."""
+ parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
+ allow_abbrev=False)
+
+ # Standard arguments.
+ parser = _add_network_size_args(parser)
+ parser = _add_regularization_args(parser)
+ parser = _add_training_args(parser)
+ parser = _add_initialization_args(parser)
+ parser = _add_learning_rate_args(parser)
+ parser = _add_checkpointing_args(parser)
+ parser = _add_mixed_precision_args(parser)
+ parser = _add_distributed_args(parser)
+ parser = _add_validation_args(parser)
+ parser = _add_data_args(parser)
+ parser = _add_autoresume_args(parser)
+ parser = _add_realm_args(parser)
+ parser = _add_zero_args(parser)
+ parser = _add_memoryopt_args(parser)
+ parser = _add_activation_checkpoint_args(parser)
+
+ # Custom arguments.
+ if extra_args_provider is not None:
+ parser = extra_args_provider(parser)
+
+ # Include DeepSpeed configuration arguments
+ parser = deepspeed.add_config_arguments(parser)
+
+ # Parse.
+ if ignore_unknown_args:
+ args, _ = parser.parse_known_args()
+ else:
+ args = parser.parse_args()
+
+ args.tokens = 0
+ # Distributed args.
+ args.rank = int(os.getenv('RANK', '0'))
+ args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+
+ # ??? fix
+ # args.model_parallel_size = min(args.model_parallel_size, args.world_size)
+ # assert args.model_parallel_size <= args.world_size # comment for merge_mp_partitions
+
+ # pad intermediate-size
+ if args.intermediate_size % args.model_parallel_size != 0:
+ new_size = (args.intermediate_size + args.model_parallel_size - 1) // args.model_parallel_size * args.model_parallel_size
+ print(f'!!! Padded intermediate-size {args.intermediate_size} to {new_size}, TP={args.model_parallel_size}')
+ args._intermediate_pad = new_size - args.intermediate_size
+ args.intermediate_size = new_size
+ else:
+ args._intermediate_pad = 0
+
+ if args.rank == 0:
+ print('using world size: {} and model-parallel size: {} '.format(
+ args.world_size, args.model_parallel_size))
+
+ # Fp16 loss scaling.
+ args.dynamic_loss_scale = False
+ if args.loss_scale is None:
+ args.dynamic_loss_scale = True
+
+ # Parameters dtype.
+ args.params_dtype = torch.float
+ if args.fp16:
+ args.params_dtype = torch.half
+ if args.rank == 0:
+ print('using {} for parameters ...'.format(args.params_dtype),
+ flush=True)
+
+
+ # Set input defaults.
+ for key in defaults:
+ # For default to be valid, it should not be provided in the
+ # arguments that are passed to the program. We check this by
+ # ensuring the arg is set to None.
+ if getattr(args, key) is not None:
+ if args.rank == 0:
+                print('WARNING: overriding default arguments for '
+                      '{key}:{v} with {key}:{v2}'.format(
+                          key=key, v=defaults[key], v2=getattr(args, key)),
+                      flush=True)
+ else:
+ setattr(args, key, defaults[key])
+
+ # Check required arguments.
+ required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
+ 'max_position_embeddings']
+ for req_arg in required_args:
+ _check_arg_is_not_none(args, req_arg)
+
+ # Checks.
+ assert args.hidden_size % args.num_attention_heads == 0
+ if args.seq_length is not None:
+ assert args.max_position_embeddings >= args.seq_length
+ if args.lr is not None:
+ assert args.min_lr <= args.lr
+ if args.save is not None:
+ assert args.save_interval is not None
+ # Parameters sharing does not work with torch DDP.
+ if (args.num_unique_layers is not None) and (args.num_layers is not None):
+ assert args.num_unique_layers <= args.num_layers
+ assert args.num_layers % args.num_unique_layers == 0, \
+ 'num-layers should be divisible by num-unique-layers.'
+ if args.num_unique_layers < args.num_layers:
+ assert args.DDP_impl == 'local', \
+ 'torch-DDP does not work with parameters sharing.'
+ # Mixed precision checks.
+ if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
+ # Activation checkpointing.
+ if args.distribute_checkpointed_activations:
+ assert args.checkpoint_activations, \
+ 'for distribute-checkpointed-activations to work you '\
+ 'need to enable checkpoint-activations'
+
+ # load scaled_upper_triang_masked_softmax_fusion kernel
+ if args.scaled_upper_triang_masked_softmax_fusion:
+ fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+
+ # load scaled_masked_softmax_fusion kernel
+ if args.scaled_masked_softmax_fusion:
+ fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+
+ _print_args(args)
+ return args
+
+
+def _print_args(args):
+ """Print arguments."""
+ if args.rank == 0:
+ print('-------------------- arguments --------------------', flush=True)
+ str_list = []
+ for arg in vars(args):
+ dots = '.' * (32 - len(arg))
+ str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
+ for arg in sorted(str_list, key=lambda x: x.lower()):
+ print(arg, flush=True)
+ print('---------------- end of arguments ----------------', flush=True)
+
+
+def _check_arg_is_not_none(args, arg):
+ assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
+
+
+def _add_network_size_args(parser):
+ group = parser.add_argument_group(title='network size')
+
+ group.add_argument('--num-layers', type=int, default=None,
+ help='Number of transformer layers.')
+ group.add_argument('--num-unique-layers', type=int, default=None,
+ help='Number of unique transformer layers. '
+ '`num-layers` should be divisible by this value.')
+ group.add_argument('--param-sharing-style', default='grouped',
+ choices=['grouped', 'spaced'],
+ help='Ordering of the shared parameters. For example, '
+ 'for a `num-layers`=4 and `--num-unique-layers`=2, '
+ 'we will have the following ordering for two unique '
+ 'layers 1 and 2: '
+ ' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
+    group.add_argument('--embedding-size', type=int, default=None,
+                       help='Transformer embedding size.')
+    group.add_argument('--hidden-size', type=int, default=None,
+                       help='Transformer hidden size.')
+    group.add_argument('--intermediate-size', type=int, default=None,
+                       help='Transformer intermediate size.')
+ group.add_argument('--activation-type', type=str, default='gelu',
+ choices=['gelu', 'geglu'])
+ group.add_argument('--num-attention-heads', type=int, default=None,
+ help='Number of transformer attention heads.')
+ group.add_argument('--max-position-embeddings', type=int, default=None,
+ help='Maximum number of position embeddings to use. '
+ 'This is the size of position embedding.')
+ group.add_argument('--pos-encoding-type', type=str, default='trainable_absolute',
+ choices=['trainable_absolute', 'rotary'])
+    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
+                       help='Pad the vocab size to be divisible by this value. '
+                       'This is added for computational efficiency reasons.')
+ group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
+ help='Layer norm epsilon.')
+ group.add_argument('--apply-residual-connection-post-layernorm',
+ action='store_true',
+                       help='If set, use original BERT residual connection '
+                       'ordering.')
+
+ return parser
+
+
+def _add_regularization_args(parser):
+ group = parser.add_argument_group(title='regularization')
+
+ group.add_argument('--attention-dropout', type=float, default=0.1,
+                       help='Post attention dropout probability.')
+ group.add_argument('--hidden-dropout', type=float, default=0.1,
+ help='Dropout probability for hidden state transformer.')
+ group.add_argument('--weight-decay', type=float, default=0.01,
+ help='Weight decay coefficient for L2 regularization.')
+ group.add_argument('--clip-grad', type=float, default=1.0,
+ help='Gradient clipping based on global L2 norm.')
+    group.add_argument('--adam-beta1', type=float, default=0.9,
+                       help='First coefficient for computing running averages '
+                       'of gradient and its square.')
+    group.add_argument('--adam-beta2', type=float, default=0.999,
+                       help='Second coefficient for computing running averages '
+                       'of gradient and its square.')
+    group.add_argument('--adam-eps', type=float, default=1e-08,
+                       help='Term added to the denominator to improve '
+                       'numerical stability.')
+
+ return parser
+
+
+def _add_training_args(parser):
+ group = parser.add_argument_group(title='training')
+
+ group.add_argument('--batch-size', type=int, default=None,
+ help='Batch size per model instance (local batch size). '
+ 'Global batch size is local batch size times data '
+ 'parallel size.')
+ group.add_argument('--checkpoint-activations', action='store_true',
+ help='Checkpoint activation to allow for training '
+ 'with larger models, sequences, and batch sizes.')
+ group.add_argument('--distribute-checkpointed-activations',
+ action='store_true',
+ help='If set, distribute checkpointed activations '
+ 'across model parallel group.')
+ group.add_argument('--checkpoint-num-layers', type=int, default=1,
+ help='chunk size (number of layers) for checkpointing.')
+ group.add_argument('--train-iters', type=int, default=None,
+ help='Total number of iterations to train over all '
+ 'training runs.')
+ group.add_argument('--train-tokens', type=int, default=None,
+ help='Total number of tokens to train over all '
+ 'training runs.')
+ group.add_argument('--log-interval', type=int, default=100,
+ help='Report loss and timing interval.')
+ group.add_argument('--exit-interval', type=int, default=None,
+ help='Exit the program after the iteration is divisible '
+ 'by this value.')
+ group.add_argument('--tensorboard-dir', type=str, default=None,
+ help='Write TensorBoard logs to this directory.')
+ group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+ action='store_true',
+                       help='Enable fusion of query_key_value scaling, '
+                       'upper-triangular (causal) masking, and softmax.')
+    group.add_argument('--scaled-masked-softmax-fusion',
+                       action='store_true',
+                       help='Enable fusion of query_key_value scaling, '
+                       'general masking, and softmax.')
+ group.add_argument('--bias-gelu-fusion', action='store_true',
+ help='Enable bias and gelu fusion.')
+ group.add_argument('--bias-dropout-fusion', action='store_true',
+ help='Enable bias and dropout fusion.')
+
+ group.add_argument('--cpu-optimizer', action='store_true',
+ help='Run optimizer on CPU')
+ group.add_argument('--cpu_torch_adam', action='store_true',
+ help='Use Torch Adam as optimizer on CPU.')
+ return parser
+
+
+def _add_initialization_args(parser):
+ group = parser.add_argument_group(title='initialization')
+
+ group.add_argument('--seed', type=int, default=1234,
+ help='Random seed used for python, numpy, '
+ 'pytorch, and cuda.')
+ group.add_argument('--init-method-std', type=float, default=0.02,
+ help='Standard deviation of the zero mean normal '
+ 'distribution used for weight initialization.')
+
+ return parser
+
+
+def _add_learning_rate_args(parser):
+ group = parser.add_argument_group(title='learning rate')
+
+ group.add_argument('--lr', type=float, default=None,
+ help='Initial learning rate. Depending on decay style '
+                            'and initial warmup, the learning rate at each '
+ 'iteration would be different.')
+ group.add_argument('--lr-decay-style', type=str, default='linear',
+ choices=['constant', 'linear', 'cosine', 'exponential'],
+ help='Learning rate decay function.')
+ group.add_argument('--lr-decay-iters', type=int, default=None,
+                       help='Number of iterations to decay learning rate over. '
+                       'If None, defaults to `--train-iters`.')
+ group.add_argument('--lr-decay-tokens', type=int, default=None,
+ help='Learning rate decay tokens.')
+ group.add_argument('--min-lr', type=float, default=0.0,
+                       help='Minimum value for learning rate. The scheduler '
+                       'clips values below this threshold.')
+ group.add_argument('--warmup', type=float, default=0.01,
+ help='Percentage of total iterations to warmup on '
+ '(.01 = 1 percent of all training iters).')
+ group.add_argument('--warmup-iters', type=int, default=None,
+                       help='Number of iterations for LR warmup. '
+                       'If not None, will override `--warmup`.')
+ group.add_argument('--override-lr-scheduler', action='store_true',
+                       help='Reset the values of the scheduler (learning rate, '
+                       'warmup iterations, minimum learning rate, maximum '
+                       'number of iterations, and decay style) from input '
+                       'arguments and ignore values from checkpoints. Note '
+                       'that all the above values will be reset.')
+ group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
+                       help='Use checkpoint to set the values of the scheduler '
+                       '(learning rate, warmup iterations, minimum learning '
+                       'rate, maximum number of iterations, and decay style) '
+                       'from checkpoint and ignore input arguments.')
+
+ return parser
+
+
+def _add_checkpointing_args(parser):
+ group = parser.add_argument_group(title='checkpointing')
+
+ group.add_argument('--save', type=str, default=None,
+ help='Output directory to save checkpoints to.')
+ group.add_argument('--save-interval', type=int, default=None,
+ help='Number of iterations between checkpoint saves.')
+ group.add_argument('--no-save-optim', action='store_true',
+ help='Do not save current optimizer.')
+ group.add_argument('--no-save-rng', action='store_true',
+ help='Do not save current rng state.')
+ group.add_argument('--load', type=str, default=None,
+ help='Directory containing a model checkpoint.')
+ group.add_argument('--load-release-checkpoint', action='store_true',
+                       help='Use the new loading path for release checkpoints.')
+ group.add_argument('--no-load-optim', action='store_true',
+ help='Do not load optimizer when loading checkpoint.')
+ group.add_argument('--no-load-rng', action='store_true',
+ help='Do not load rng state when loading checkpoint.')
+ group.add_argument('--finetune', action='store_true',
+ help='Load model for finetuning. Do not load optimizer '
+ 'or rng state from checkpoint and set iteration to 0. '
+ 'Assumed when loading a release checkpoint.')
+
+ return parser
+
+
+def _add_mixed_precision_args(parser):
+ group = parser.add_argument_group(title='mixed precision')
+
+ group.add_argument('--fp16', action='store_true',
+ help='Run model in fp16 mode.')
+ group.add_argument('--apply-query-key-layer-scaling', action='store_true',
+ help='Scale Q * K^T by 1 / layer-number. If this flag '
+ 'is set, then it will automatically set '
+ 'attention-softmax-in-fp32 to true')
+ group.add_argument('--attention-softmax-in-fp32', action='store_true',
+ help='Run attention masking and softmax in fp32.')
+ group.add_argument('--fp32-allreduce', action='store_true',
+ help='All-reduce in fp32')
+ group.add_argument('--hysteresis', type=int, default=2,
+ help='hysteresis for dynamic loss scaling')
+ group.add_argument('--loss-scale', type=float, default=None,
+                       help='Static loss scaling, positive power of 2 '
+                       'values can improve fp16 convergence. If None, '
+                       'dynamic loss scaling is used.')
+ group.add_argument('--loss-scale-window', type=float, default=1000,
+ help='Window over which to raise/lower dynamic scale.')
+ group.add_argument('--min-scale', type=float, default=1,
+ help='Minimum loss scale for dynamic loss scale.')
+ group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the cross entropy unreduced loss calculation '
+                       'for lm head to fp16.')
+
+
+ return parser
+
+
+def _add_distributed_args(parser):
+    group = parser.add_argument_group(title='distributed')
+
+    group.add_argument('--model-parallel-size', type=int, default=1,
+                       help='Size of the model parallel group.')
+ group.add_argument('--distributed-backend', default='nccl',
+ choices=['nccl', 'gloo'],
+ help='Which backend to use for distributed training.')
+ group.add_argument('--DDP-impl', default='local',
+ choices=['local', 'torch'],
+ help='which DistributedDataParallel implementation '
+ 'to use.')
+ group.add_argument('--local_rank', type=int, default=None,
+ help='local rank passed from distributed launcher.')
+ group.add_argument('--lazy-mpu-init', type=bool, required=False,
+                       help='If set to True, initialize_megatron() skips DDP '
+                       'initialization and returns a function to complete it '
+                       'instead. Also turns on the --use-cpu-initialization '
+                       'flag. This is for an external DDP manager.')
+ group.add_argument('--use-cpu-initialization', action='store_true',
+ help='If set, affine parallel weights initialization uses CPU' )
+ return parser
+
+
+def _add_validation_args(parser):
+ group = parser.add_argument_group(title='validation')
+
+ group.add_argument('--eval-iters', type=int, default=100,
+                       help='Number of iterations to run for evaluation on '
+                       'the validation/test set.')
+ group.add_argument('--eval-interval', type=int, default=1000,
+ help='Interval between running evaluation on '
+ 'validation set.')
+
+ return parser
+
+
+def _add_data_args(parser):
+ group = parser.add_argument_group(title='data and dataloader')
+
+ group.add_argument('--data-path', type=str, default=None,
+ help='Path to combined dataset to split.')
+ group.add_argument('--split', type=str, default='969, 30, 1',
+ help='Comma-separated list of proportions for training,'
+ ' validation, and test split. For example the split '
+                       '`90,5,5` will use 90%% of data for training, 5%% for '
+                       'validation and 5%% for test.')
+ group.add_argument('--vocab-file', type=str, default=None,
+ help='Path to the vocab file.')
+ group.add_argument('--merge-file', type=str, default=None,
+ help='Path to the BPE merge file.')
+ group.add_argument('--seq-length', type=int, default=None,
+ help="Maximum sequence length to process.")
+ group.add_argument('--mask-prob', type=float, default=0.15,
+ help='Probability of replacing a token with mask.')
+ group.add_argument('--short-seq-prob', type=float, default=0.1,
+ help='Probability of producing a short sequence.')
+ group.add_argument('--mmap-warmup', action='store_true',
+ help='Warm up mmap files.')
+ group.add_argument('--num-workers', type=int, default=2,
+ help="Dataloader number of workers.")
+ group.add_argument('--tokenizer-type', type=str,
+ default=None,
+ choices=['SentencePiece',
+ 'BertWordPieceLowerCase',
+ 'BertWordPieceCase',
+ 'GPT2BPETokenizer'],
+ help='What type of tokenizer to use.')
+ group.add_argument('--data-impl', type=str, default='infer',
+ choices=['lazy', 'cached', 'mmap', 'infer'],
+ help='Implementation of indexed datasets.')
+    group.add_argument('--reset-position-ids', action='store_true',
+                       help='Reset position ids after end-of-document token.')
+    group.add_argument('--reset-attention-mask', action='store_true',
+                       help='Reset self attention mask after '
+                       'end-of-document token.')
+ group.add_argument('--eod-mask-loss', action='store_true',
+ help='Mask loss for the end of document tokens.')
+
+ return parser
+
+
+def _add_autoresume_args(parser):
+ group = parser.add_argument_group(title='autoresume')
+
+ group.add_argument('--adlr-autoresume', action='store_true',
+ help='Enable autoresume on adlr cluster.')
+ group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
+                       help='Interval over which to check for the autoresume '
+                       'termination signal.')
+
+ return parser
+
+
+def _add_realm_args(parser):
+ group = parser.add_argument_group(title='realm')
+
+ # network size
+ group.add_argument('--ict-head-size', type=int, default=None,
+ help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
+
+ # checkpointing
+ group.add_argument('--ict-load', type=str, default=None,
+ help='Directory containing an ICTBertModel checkpoint')
+ group.add_argument('--bert-load', type=str, default=None,
+                       help='Directory containing a BertModel checkpoint (needed to start ICT and REALM)')
+
+ # data
+ group.add_argument('--titles-data-path', type=str, default=None,
+ help='Path to titles dataset used for ICT')
+ group.add_argument('--query-in-block-prob', type=float, default=0.1,
+ help='Probability of keeping query in block for ICT dataset')
+ group.add_argument('--use-one-sent-docs', action='store_true',
+ help='Whether to use one sentence documents in ICT')
+
+ # training
+ group.add_argument('--report-topk-accuracies', nargs='+', default=[],
+ help="Which top-k accuracies to report (e.g. '1 5 20')")
+
+ # faiss index
+ group.add_argument('--faiss-use-gpu', action='store_true',
+                       help='Whether to create the FaissMIPSIndex on GPU')
+ group.add_argument('--block-data-path', type=str, default=None,
+ help='Where to save/load BlockData to/from')
+
+ # indexer
+ group.add_argument('--indexer-batch-size', type=int, default=128,
+ help='How large of batches to use when doing indexing jobs')
+ group.add_argument('--indexer-log-interval', type=int, default=1000,
+ help='After how many batches should the indexer report progress')
+ return parser
+
+
+def _add_zero_args(parser):
+ """Text generate arguments."""
+
+ group = parser.add_argument_group('Text generation', 'configurations')
+ group.add_argument("--zero-stage", type=int, default=1.0)
+ group.add_argument('--zero-reduce-scatter', action='store_true',
+ help='Use reduce scatter if specified')
+ group.add_argument('--zero-contigious-gradients', action='store_true',
+ help='Use contigious memory optimizaiton if specified')
+ group.add_argument("--zero-reduce-bucket-size", type=int, default=0.0)
+ group.add_argument("--zero-allgather-bucket-size", type=int, default=0.0)
+ group.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'],
+ help='Remote device for ZeRO-3 initialized parameters.')
+ group.add_argument('--use-pin-memory', action='store_true',
+ help='Use pinned CPU memory for ZeRO-3 initialized model parameters.')
+ return parser
+
+def _add_memoryopt_args(parser):
+ """Memory optimization arguments."""
+
+ group = parser.add_argument_group('Memory optimizations', 'configurations')
+ group.add_argument("--scattered-embeddings", action='store_true',
+ help='Save memory by scattering embedding activations. '
+ 'Introduces dropout differences across MP configurations.')
+ group.add_argument("--split-transformers", action='store_true',
+ help='Save memory by splitting transformer layers into two parts, '
+ 'allowing for more frequent activation checkpoint savings.')
+ group.add_argument("--memory-centric-tiled-linear", action="store_true",
+ help='Save memory by tiling with deepspeed.zero.TiledLinear.')
+ group.add_argument("--tile-factor", type=int, default=1,
+ help='Make all linear layers the same size of [hidden/tile_factor, hidden/tile_factor]. '
+ 'Must be enabled with --memory-centric-tiled-linear. '
+ 'Example A: if tile_factor=1, the qkv layer [hidden, 3* hidden] would be converted into [1,3] tiles of size [hidden,hidden]. '
+ 'Example B: if tile_factor=2, the intermediate layer [4*hidden, hidden] will be converted into [8, 2] tiles of size [hidden/2, hidden/2]. '
+ 'Default is 1.')
+
+ return parser
+
+
+
+
+def _add_activation_checkpoint_args(parser):
+ group = parser.add_argument_group('Activation Checkpointing',
+ 'Checkpointing Configurations')
+ group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
+                       help='Use activation checkpointing from DeepSpeed.')
+    group.add_argument('--partition-activations', action='store_true',
+                       help='Partition activations across GPUs before checkpointing.')
+    group.add_argument('--contigious-checkpointing', action='store_true',
+                       help='Contiguous memory checkpointing for activations.')
+    group.add_argument('--checkpoint-in-cpu', action='store_true',
+                       help='Move the activation checkpoints to CPU.')
+    group.add_argument('--synchronize-each-layer', action='store_true',
+                       help='Synchronize at the beginning and end of each checkpointed layer.')
+ group.add_argument('--profile-backward', action='store_true',
+ help='Enables backward pass profiling for checkpointed layers.')
+ return parser
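
A hypothetical driver for parse_args() above (not part of the patch). The '--top-k' flag and all concrete values are made up for illustration; note that --intermediate-size has to arrive via the command line because the padding step above runs before the defaults dict is applied.

# Illustrative sketch, not part of the patch: driving parse_args() with an
# extra argument group and a defaults dict that only fills arguments left at None.
import sys

from megatron.arguments import parse_args

def extra_args(parser):
    # Hypothetical extra_args_provider hook.
    group = parser.add_argument_group(title='sampling')
    group.add_argument('--top-k', type=int, default=0,
                       help='Top-k sampling cutoff (0 disables it).')
    return parser

if __name__ == '__main__':
    # Flags a launcher script would normally pass; added here so the sketch is self-contained.
    sys.argv += ['--intermediate-size', '4096', '--model-parallel-size', '1']
    args = parse_args(
        extra_args_provider=extra_args,
        defaults={'num_layers': 24, 'hidden_size': 1024,
                  'num_attention_heads': 16, 'max_position_embeddings': 2048,
                  'seq_length': 2048},
        ignore_unknown_args=True)
    print(args.top_k, args.params_dtype)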
diff --git a/megatron_lm/megatron/checkpointing.py b/megatron_lm/megatron/checkpointing.py
new file mode 100644
index 0000000..ffd95a6
--- /dev/null
+++ b/megatron_lm/megatron/checkpointing.py
@@ -0,0 +1,464 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Input/output checkpointing."""
+
+import os
+import random
+import re
+import sys
+import numpy as np
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
+
+import torch.nn.functional as F
+
+from megatron import mpu, get_args
+from megatron import print_rank_0
+
+_CHECKPOINT_VERSION = None
+
+def set_checkpoint_version(value):
+ global _CHECKPOINT_VERSION
+ assert _CHECKPOINT_VERSION is None, \
+ "checkpoint version already set"
+ _CHECKPOINT_VERSION = value
+
+def get_checkpoint_version():
+ global _CHECKPOINT_VERSION
+ return _CHECKPOINT_VERSION
+
+def check_checkpoint_args(checkpoint_args):
+ """Ensure fixed arguments for a model are the same for the input
+ arguments and the one retreived frm checkpoint."""
+ args = get_args()
+
+ def _compare(arg_name):
+ checkpoint_value = getattr(checkpoint_args, arg_name)
+ args_value = getattr(args, arg_name)
+ error_message = '{} value from checkpoint ({}) is not equal to the ' \
+ 'input argument value ({}).'.format(
+ arg_name, checkpoint_value, args_value)
+ assert checkpoint_value == args_value, error_message
+
+ _compare('num_layers')
+ _compare('hidden_size')
+ _compare('num_attention_heads')
+ _compare('max_position_embeddings')
+ _compare('make_vocab_size_divisible_by')
+ _compare('padded_vocab_size')
+ _compare('tokenizer_type')
+ _compare('model_parallel_size')
+
+
+def ensure_directory_exists(filename):
+ """Build filename's path if it does not already exists."""
+ dirname = os.path.dirname(filename)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+
+def get_checkpoint_name(checkpoints_path, iteration,
+ release=False, mp_rank=None):
+ """A unified checkpoint name."""
+ if release:
+ directory = 'release'
+ else:
+ directory = 'iter_{:07d}'.format(iteration)
+ return os.path.join(checkpoints_path, directory,
+ 'mp_rank_{:02d}'.format(
+ mpu.get_model_parallel_rank() if mp_rank is None
+ else mp_rank),
+ 'model_optim_rng.pt')
+
+
+def get_checkpoint_tracker_filename(checkpoints_path):
+ """Tracker file rescords the latest chckpoint during
+ training to restart from."""
+ return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
+
+
+def save_ds_checkpoint(iteration, model, args):
+ """Save a model checkpoint."""
+
+ sd = {}
+ sd['iteration'] = iteration
+ sd['tokens'] = args.tokens
+ sd['checkpoint_version'] = 2.0
+ sd['args'] = args
+
+ # rng states.
+ if not args.no_save_rng:
+ sd['random_rng_state'] = random.getstate()
+ sd['np_rng_state'] = np.random.get_state()
+ sd['torch_rng_state'] = torch.get_rng_state()
+ sd['cuda_rng_state'] = torch.cuda.get_rng_state()
+ sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
+
+    # The megatron model uses state_dict_for_save_checkpoint instead of the
+    # standard state_dict. state_dict is used by deepspeed for module saving,
+    # so it needs to point to the right function.
+ original_state_dict = model.module.state_dict
+ model.module.state_dict = model.module.state_dict_for_save_checkpoint
+ try:
+ model.save_checkpoint(args.save, client_state=sd)
+ finally:
+ model.module.state_dict = original_state_dict
+
+
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+ """Save a model checkpoint."""
+ args = get_args()
+ # args.save = 'rewrite'
+
+ if args.deepspeed:
+ save_ds_checkpoint(iteration, model, args)
+ else:
+ # Only rank zero of the data parallel writes to the disk.
+ if isinstance(model, torchDDP):
+ model = model.module
+ if mpu.get_data_parallel_rank() == 0:
+
+ # Arguments, iteration, and model.
+ state_dict = {}
+ state_dict['args'] = args
+ state_dict['checkpoint_version'] = 2.0
+ state_dict['iteration'] = iteration
+ state_dict['tokens'] = args.tokens
+ state_dict['model'] = model.state_dict_for_save_checkpoint()
+
+ # Optimizer stuff.
+ if not args.no_save_optim:
+ if optimizer is not None:
+ state_dict['optimizer'] = optimizer.state_dict()
+ if lr_scheduler is not None:
+ state_dict['lr_scheduler'] = lr_scheduler.state_dict()
+
+ # RNG states.
+ if not args.no_save_rng:
+ state_dict['random_rng_state'] = random.getstate()
+ state_dict['np_rng_state'] = np.random.get_state()
+ state_dict['torch_rng_state'] = torch.get_rng_state()
+ state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
+ state_dict['rng_tracker_states'] \
+ = mpu.get_cuda_rng_tracker().get_states()
+
+ # Save.
+ checkpoint_name = get_checkpoint_name(args.save, iteration)
+ print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
+ format(torch.distributed.get_rank(), iteration,
+ checkpoint_name))
+ ensure_directory_exists(checkpoint_name)
+ torch.save(state_dict, checkpoint_name)
+ print(' successfully saved {}'.format(checkpoint_name))
+
+ # Wait so everyone is done (necessary)
+ torch.distributed.barrier()
+ # And update the latest iteration
+ if torch.distributed.get_rank() == 0:
+ tracker_filename = get_checkpoint_tracker_filename(args.save)
+ with open(tracker_filename, 'w') as f:
+ f.write(str(iteration))
+ # Wait so everyone is done (not necessary)
+ torch.distributed.barrier()
+
+
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
+ """Load a model checkpoint and return the iteration."""
+ args = get_args()
+ if args.load_release_checkpoint:
+ load_checkpoint_new(model, optimizer, lr_scheduler)
+ return 0
+
+ load_dir = getattr(args, load_arg)
+
+ if isinstance(model, torchDDP):
+ model = model.module
+ # Read the tracker file and set the iteration.
+ tracker_filename = get_checkpoint_tracker_filename(load_dir)
+
+    # If no tracker file, return iteration zero.
+ if not os.path.isfile(tracker_filename):
+ print_rank_0('WARNING: could not find the metadata file {} '.format(
+ tracker_filename))
+ print_rank_0(' will not load any checkpoints and will start from '
+ 'random')
+ return 0
+
+ # Otherwise, read the tracker file and either set the iteration or
+ # mark it as a release checkpoint.
+ iteration = 0
+ release = False
+ with open(tracker_filename, 'r') as f:
+ metastring = f.read().strip()
+ try:
+ iteration = int(metastring)
+ except ValueError:
+ release = metastring == 'release'
+ if not release:
+ print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
+ tracker_filename))
+ sys.exit()
+
+ assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+ tracker_filename)
+
+ if args.deepspeed:
+ checkpoint_name, state_dict = model.load_checkpoint(load_dir)
+
+ if checkpoint_name is None:
+ if mpu.get_data_parallel_rank() == 0:
+ print("Unable to load checkpoint.")
+ return iteration
+
+ else:
+ # Checkpoint.
+ checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
+ if mpu.get_data_parallel_rank() == 0:
+ print('global rank {} is loading checkpoint {}'.format(
+ torch.distributed.get_rank(), checkpoint_name))
+
+ # Load the checkpoint.
+ try:
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
+ except ModuleNotFoundError:
+ # For backward compatibility.
+ print_rank_0(' > deserializing using the old code structure ...')
+ sys.modules['fp16.loss_scaler'] = sys.modules[
+ 'megatron.fp16.loss_scaler']
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
+ sys.modules.pop('fp16.loss_scaler', None)
+ except BaseException:
+ print_rank_0('could not load the checkpoint')
+ sys.exit()
+ # Model.
+
+ # print('>>>', model.state_dict().keys())
+ # print('<<<', state_dict['model'].keys())
+ if 'model' in state_dict:
+ model.load_state_dict(state_dict['model'])
+ else:
+ # This is a HACK to load deepspeed checkpoint's model state even if not initialized with deepspeed
+ model.load_state_dict(state_dict['module'])
+
+ # Optimizer.
+ if not release and not args.finetune and not args.no_load_optim:
+ try:
+ if optimizer is not None:
+ optimizer.load_state_dict(state_dict['optimizer'])
+ if lr_scheduler is not None:
+ lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+ except KeyError:
+ print_rank_0(
+ 'Unable to load optimizer from checkpoint {}. '
+ 'Specify --no-load-optim or --finetune to prevent '
+ 'attempting to load the optimizer state, '
+ 'exiting ...'.format(checkpoint_name))
+ sys.exit()
+
+ # set checkpoint version
+ set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
+ # Set iteration.
+ if args.finetune or release:
+ iteration = 0
+ else:
+ try:
+ iteration = state_dict['iteration']
+ if 'tokens' in state_dict:
+ args.tokens = state_dict['tokens']
+ except KeyError:
+ try: # Backward compatible with older checkpoints
+ iteration = state_dict['total_iters']
+ except KeyError:
+ print_rank_0('A metadata file exists but unable to load '
+ 'iteration from checkpoint {}, exiting'.format(
+ checkpoint_name))
+ sys.exit()
+
+
+ # Check arguments.
+ if 'args' in state_dict:
+ checkpoint_args = state_dict['args']
+ check_checkpoint_args(checkpoint_args)
+ else:
+ print_rank_0('could not find arguments in the checkpoint ...')
+
+ # rng states.
+ if not release and not args.finetune and not args.no_load_rng:
+ try:
+ random.setstate(state_dict['random_rng_state'])
+ np.random.set_state(state_dict['np_rng_state'])
+ torch.set_rng_state(state_dict['torch_rng_state'])
+ torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+ mpu.get_cuda_rng_tracker().set_states(
+ state_dict['rng_tracker_states'])
+ except KeyError:
+            print_rank_0('Unable to load rng state from checkpoint {}. '
+                         'Specify --no-load-rng or --finetune to prevent '
+                         'attempting to load the rng state, '
+                         'exiting ...'.format(checkpoint_name))
+ sys.exit()
+
+ torch.distributed.barrier()
+ if mpu.get_data_parallel_rank() == 0:
+ print(' successfully loaded {}'.format(checkpoint_name))
+
+ return iteration
+
+
+def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+ """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+
+ args = get_args()
+
+ if isinstance(model, torchDDP):
+ model = model.module
+
+ load_path = args.load if from_realm_chkpt else args.ict_load
+
+ tracker_filename = get_checkpoint_tracker_filename(load_path)
+ with open(tracker_filename, 'r') as f:
+ iteration = int(f.read().strip())
+
+ # assert iteration > 0
+ checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+ if mpu.get_data_parallel_rank() == 0:
+ print('global rank {} is loading checkpoint {}'.format(
+ torch.distributed.get_rank(), checkpoint_name))
+
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
+ ict_state_dict = state_dict['model']
+ if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
+ print(" loading ICT state dict from REALM", flush=True)
+ ict_state_dict = ict_state_dict['retriever']['ict_model']
+
+ if only_query_model:
+ ict_state_dict.pop('context_model')
+ if only_block_model:
+ ict_state_dict.pop('question_model')
+
+ model.load_state_dict(ict_state_dict)
+ torch.distributed.barrier()
+
+ if mpu.get_data_parallel_rank() == 0:
+ print(' successfully loaded {}'.format(checkpoint_name))
+
+ return model
+
+
+def load_checkpoint_new(model, optimizer, lr_scheduler):
+ args = get_args()
+ load_dir = args.load
+ mp_rank = mpu.get_model_parallel_rank()
+ mp_size = args.model_parallel_size
+ if mp_size == 1:
+ mp_str = ''
+ else:
+ mp_str = f' [mp {mp_rank:02d} / {mp_size}]'
+
+ torch.distributed.barrier()
+ print_rank_0(f'> Start loading from release checkpoint from folder {load_dir}')
+
+ state_dict = model.state_dict()
+ is_loaded = dict.fromkeys(state_dict.keys(), False)
+
+ for name_pt in os.listdir(load_dir):
+ if not re.fullmatch(r'layer_\d\d-model_00-model_states\.pt', name_pt):
+ print(f'>> Found {name_pt}, skipping it')
+ continue
+
+ fname = os.path.join(load_dir, name_pt)
+ print(f'>> Loading {name_pt} on CPU{mp_str}')
+ part = torch.load(fname, map_location='cpu')
+
+ for key, weight in part.items():
+ key_converted = map_key(key, name_pt, args.num_layers)
+ if key_converted is None or key_converted not in state_dict:
+ print(f'>>> Skip {key} (converted: {key_converted}) which is not in state_dict{mp_str}')
+ continue
+
+ print(f'>>> Setting {key} to {key_converted}{mp_str}')
+ old_shape = weight.shape
+ weight = pad_weight_if_needed(key, weight, args._intermediate_pad)
+ if old_shape != weight.shape:
+ print(f'>>>> Pad {key} from {old_shape} to {weight.shape}')
+
+ tensor = state_dict[key_converted]
+ if weight.shape == tensor.shape:
+ tensor.copy_(weight)
+ else:
+ assert mp_size > 1, f"mp is 1, but loaded {key} shape: {weight.shape}, state_dict {key_converted}: {tensor.shape}"
+ assert len(weight.shape) == len(tensor.shape)
+ diff_dim = None
+ for i in range(len(weight.shape)):
+ if weight.shape[i] != tensor.shape[i]:
+ assert diff_dim is None, f"Loaded {key} shape: {weight.shape}, state_dict {key_converted}: {tensor.shape}"
+ diff_dim = i
+ num_partitions = mpu.divide(weight.shape[diff_dim], tensor.shape[diff_dim])
+ weights = split_into_partitions(weight, num_partitions, diff_dim)
+ weight_mp = weights[mp_rank]
+ assert weight_mp.shape == tensor.shape, f"Split didn't help on {key}: checkpoint is {weight.shape}, state_dict is {tensor.shape}, split is {weight_mp.shape}{mp_str}"
+ tensor.copy_(weight_mp)
+ is_loaded[key_converted] = True
+
+ for key, flag in is_loaded.items():
+ if not flag:
+ print(f'> !!! {key} has not been found in checkpoint{mp_str}')
+ torch.distributed.barrier()
+ print('> Finish loading from release checkpoint')
+
+
+def map_key(key, name_pt, num_layers):
+ '''Map state dict key from checkpoint to the current model'''
+ num_pt = int(name_pt[6:8])
+ if num_pt == 0:
+ return f'language_model.embedding.{key}'
+ if num_pt == 1:
+ return f'language_model.projector.{key}'
+ num_state = num_pt - 3
+ if 0 <= num_state < num_layers:
+ return f'language_model.transformer.layers.{num_state}.{key}'
+ if num_pt == num_layers + 4:
+ return f'language_model.output_layer.{key}'
+
+
+def split_into_partitions(tensor, num_partitions, partition_dim, stride=1):
+
+ per_partition_size = mpu.utils.divide(tensor.size(partition_dim), num_partitions)
+ per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
+
+ partitions_list = torch.split(tensor, per_partition_per_stride_size, dim=partition_dim)
+
+ partitions = []
+ for i in range(num_partitions):
+ partition = torch.cat(partitions_list[i::num_partitions], dim=partition_dim)
+ partitions.append(partition)
+
+ return partitions
+
+
+def pad_weight_if_needed(key, weight, intermediate_pad):
+ if intermediate_pad == 0:
+ return weight
+ if 'dense_ffn_gate' in key or 'dense_ffn_hidden' in key:
+ if 'weight' in key:
+ return F.pad(weight, (0, 0, 0, intermediate_pad))
+ else:
+ return F.pad(weight, (0, intermediate_pad))
+ if 'dense_ffn_output.weight' in key:
+ return F.pad(weight, (0, intermediate_pad))
+ return weight
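
split_into_partitions() above relies on mpu.utils.divide. The self-contained sketch below (not part of the patch) mirrors the same interleaved splitting with a plain exact-division helper; this is what load_checkpoint_new() effectively does when a full release-checkpoint weight must be sliced for a single model-parallel rank.

# Illustrative sketch, not part of the patch.
import torch

def divide(numerator, denominator):
    # Mirrors the exact-division helper assumed above.
    assert numerator % denominator == 0, \
        '{} is not divisible by {}'.format(numerator, denominator)
    return numerator // denominator

def split_for_model_parallel(weight, num_partitions, partition_dim, stride=1):
    per_partition = divide(weight.size(partition_dim), num_partitions)
    per_stride = divide(per_partition, stride)
    chunks = torch.split(weight, per_stride, dim=partition_dim)
    # Interleave stride chunks so each rank gets every num_partitions-th one.
    return [torch.cat(chunks[i::num_partitions], dim=partition_dim)
            for i in range(num_partitions)]

# Example: a full [4096, 1024] checkpoint weight split across 4 model-parallel
# ranks along dim 0, the shape-mismatch case handled in load_checkpoint_new().
full = torch.randn(4096, 1024)
parts = split_for_model_parallel(full, num_partitions=4, partition_dim=0)
assert parts[0].shape == (4096 // 4, 1024)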
diff --git a/megatron_lm/megatron/data/Makefile b/megatron_lm/megatron/data/Makefile
new file mode 100644
index 0000000..8f9db76
--- /dev/null
+++ b/megatron_lm/megatron/data/Makefile
@@ -0,0 +1,9 @@
+CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
+CPPFLAGS += $(shell python3 -m pybind11 --includes)
+LIBNAME = helpers
+LIBEXT = $(shell python3-config --extension-suffix)
+
+default: $(LIBNAME)$(LIBEXT)
+
+%$(LIBEXT): %.cpp
+ $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
diff --git a/megatron_lm/megatron/data/__init__.py b/megatron_lm/megatron/data/__init__.py
new file mode 100644
index 0000000..cd5f898
--- /dev/null
+++ b/megatron_lm/megatron/data/__init__.py
@@ -0,0 +1 @@
+from . import indexed_dataset
diff --git a/megatron_lm/megatron/data/bert_dataset.py b/megatron_lm/megatron/data/bert_dataset.py
new file mode 100644
index 0000000..5203666
--- /dev/null
+++ b/megatron_lm/megatron/data/bert_dataset.py
@@ -0,0 +1,232 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT Style dataset."""
+
+import os
+import time
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+from megatron import get_tokenizer, get_args
+from megatron import print_rank_0
+from megatron import mpu
+from megatron.data.dataset_utils import get_a_and_b_segments
+from megatron.data.dataset_utils import truncate_segments
+from megatron.data.dataset_utils import create_tokens_and_tokentypes
+from megatron.data.dataset_utils import pad_and_convert_to_numpy
+from megatron.data.dataset_utils import create_masked_lm_predictions
+
+
+class BertDataset(Dataset):
+
+ def __init__(self, name, indexed_dataset, data_prefix,
+ num_epochs, max_num_samples, masked_lm_prob,
+ max_seq_length, short_seq_prob, seed):
+
+ # Params to store.
+ self.name = name
+ self.seed = seed
+ self.masked_lm_prob = masked_lm_prob
+ self.max_seq_length = max_seq_length
+
+ # Dataset.
+ self.indexed_dataset = indexed_dataset
+
+ # Build the samples mapping.
+ self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
+ data_prefix,
+ num_epochs,
+ max_num_samples,
+ self.max_seq_length,
+ short_seq_prob,
+ self.seed,
+ self.name)
+
+ # Vocab stuff.
+ tokenizer = get_tokenizer()
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
+ self.cls_id = tokenizer.cls
+ self.sep_id = tokenizer.sep
+ self.mask_id = tokenizer.mask
+ self.pad_id = tokenizer.pad
+
+ def __len__(self):
+ return self.samples_mapping.shape[0]
+
+ def __getitem__(self, idx):
+ start_idx, end_idx, seq_length = self.samples_mapping[idx]
+ sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
+ # Note that this rng state should be numpy and not python since
+ # python randint is inclusive whereas the numpy one is exclusive.
+ np_rng = np.random.RandomState(seed=(self.seed + idx))
+ return build_training_sample(sample, seq_length,
+ self.max_seq_length, # needed for padding
+ self.vocab_id_list,
+ self.vocab_id_to_token_dict,
+ self.cls_id, self.sep_id,
+ self.mask_id, self.pad_id,
+ self.masked_lm_prob, np_rng)
+
+
+def get_samples_mapping_(indexed_dataset,
+ data_prefix,
+ num_epochs,
+ max_num_samples,
+ max_seq_length,
+ short_seq_prob,
+ seed,
+ name):
+ if not num_epochs:
+ if not max_num_samples:
+ raise ValueError("Need to specify either max_num_samples "
+ "or num_epochs")
+ num_epochs = np.iinfo(np.int32).max - 1
+ if not max_num_samples:
+ max_num_samples = np.iinfo(np.int64).max - 1
+
+ # Filename of the index mapping
+ indexmap_filename = data_prefix
+ indexmap_filename += '_{}_indexmap'.format(name)
+ if num_epochs != (np.iinfo(np.int32).max - 1):
+ indexmap_filename += '_{}ep'.format(num_epochs)
+ if max_num_samples != (np.iinfo(np.int64).max - 1):
+ indexmap_filename += '_{}mns'.format(max_num_samples)
+ indexmap_filename += '_{}msl'.format(max_seq_length)
+ indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
+ indexmap_filename += '_{}s'.format(seed)
+ indexmap_filename += '.npy'
+
+ # Build the indexed mapping if not exist.
+ if torch.distributed.get_rank() == 0 and \
+ not os.path.isfile(indexmap_filename):
+ print(' > WARNING: could not find index map file {}, building '
+ 'the indices on rank 0 ...'.format(indexmap_filename))
+
+ # Make sure the types match the helpers input types.
+ assert indexed_dataset.doc_idx.dtype == np.int64
+ assert indexed_dataset.sizes.dtype == np.int32
+
+ # Build samples mapping
+ verbose = torch.distributed.get_rank() == 0
+ start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+ name))
+ # First compile and then import.
+ from megatron.data.dataset_utils import compile_helper
+ compile_helper()
+ from megatron.data import helpers
+ samples_mapping = helpers.build_mapping(
+ indexed_dataset.doc_idx,
+ indexed_dataset.sizes,
+ num_epochs,
+ max_num_samples,
+ max_seq_length - 3, # account for added tokens
+ short_seq_prob,
+ seed,
+ verbose)
+        print_rank_0(' > done building samples index mapping')
+ np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+ print_rank_0(' > saved the index mapping in {}'.format(
+ indexmap_filename))
+ # Make sure all the ranks have built the mapping
+        print_rank_0(' > elapsed time to build and save samples mapping '
+ '(seconds): {:4f}'.format(
+ time.time() - start_time))
+ # This should be a barrier but nccl barrier assumes
+ # device_index=rank which is not the case for model
+ # parallel case
+ counts = torch.cuda.LongTensor([1])
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+ assert counts[0].item() == torch.distributed.get_world_size(
+ group=mpu.get_data_parallel_group())
+
+ # Load indexed dataset.
+ print_rank_0(' > loading indexed mapping from {}'.format(
+ indexmap_filename))
+ start_time = time.time()
+ samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
+ time.time() - start_time))
+ print_rank_0(' total number of samples: {}'.format(
+ samples_mapping.shape[0]))
+
+ return samples_mapping
+
+
+def build_training_sample(sample,
+ target_seq_length, max_seq_length,
+ vocab_id_list, vocab_id_to_token_dict,
+ cls_id, sep_id, mask_id, pad_id,
+ masked_lm_prob, np_rng):
+ """Biuld training sample.
+
+ Arguments:
+        sample: A list of sentences in which each sentence is a list of token ids.
+ target_seq_length: Desired sequence length.
+ max_seq_length: Maximum length of the sequence. All values are padded to
+ this length.
+ vocab_id_list: List of vocabulary ids. Used to pick a random id.
+ vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
+ cls_id: Start of example id.
+ sep_id: Separator id.
+ mask_id: Mask token id.
+ pad_id: Padding token id.
+ masked_lm_prob: Probability to mask tokens.
+        np_rng: Random number generator. Note that this rng state should be
+              numpy and not python since python randint is inclusive for
+              the upper bound whereas the numpy one is exclusive.
+ """
+
+ # We assume that we have at least two sentences in the sample
+ assert len(sample) > 1
+ assert target_seq_length <= max_seq_length
+
+ # Divide sample into two segments (A and B).
+ tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
+
+    # Truncate to `target_seq_length`.
+ max_num_tokens = target_seq_length
+ truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
+ len(tokens_b), max_num_tokens, np_rng)
+
+    # Build tokens and tokentypes.
+ tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
+ cls_id, sep_id)
+
+ # Masking.
+ max_predictions_per_seq = masked_lm_prob * max_num_tokens
+ (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
+ tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+ cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
+
+ # Padding.
+ tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
+ = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+ masked_labels, pad_id, max_seq_length)
+
+ train_sample = {
+ 'text': tokens_np,
+ 'types': tokentypes_np,
+ 'labels': labels_np,
+ 'is_random': int(is_next_random),
+ 'loss_mask': loss_mask_np,
+ 'padding_mask': padding_mask_np,
+ 'truncated': int(truncated)}
+ return train_sample
+
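get_samples_mapping_() above derives a cache filename from the dataset parameters before building or loading the index mapping. The helper below (not part of the patch, names are illustrative) simply reproduces that naming scheme so one can predict which .npy index map a given configuration will create or reuse.

# Illustrative sketch, not part of the patch.
import numpy as np

def samples_mapping_filename(data_prefix, name, num_epochs, max_num_samples,
                             max_seq_length, short_seq_prob, seed):
    # Mirrors the filename composition in get_samples_mapping_() above.
    filename = data_prefix + '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        filename += '_{}mns'.format(max_num_samples)
    filename += '_{}msl'.format(max_seq_length)
    filename += '_{:0.2f}ssp'.format(short_seq_prob)
    filename += '_{}s'.format(seed)
    return filename + '.npy'

# Prints 'corpus_train_indexmap_3ep_512msl_0.10ssp_1234s.npy' when only
# num_epochs is given and max_num_samples is left at its sentinel value.
print(samples_mapping_filename('corpus', 'train', 3,
                               np.iinfo(np.int64).max - 1, 512, 0.1, 1234))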
diff --git a/megatron_lm/megatron/data/dataset_utils.py b/megatron_lm/megatron/data/dataset_utils.py
new file mode 100644
index 0000000..d51b1ce
--- /dev/null
+++ b/megatron_lm/megatron/data/dataset_utils.py
@@ -0,0 +1,503 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Most of the code here has been copied from:
+# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
+# with some modifications.
+
+import time
+import collections
+
+import numpy as np
+from megatron import get_args, print_rank_0
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+
+DSET_TYPE_STD = 'standard_bert'
+DSET_TYPE_ICT = 'ict'
+
+DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
+
+
+def compile_helper():
+ """Compile helper function ar runtime. Make sure this
+ is invoked on a single process."""
+ import os
+ import subprocess
+ path = os.path.abspath(os.path.dirname(__file__))
+ ret = subprocess.run(['make', '-C', path])
+ if ret.returncode != 0:
+ print("Making C++ dataset helpers module failed, exiting.")
+ import sys
+ sys.exit(1)
+
+
+def get_a_and_b_segments(sample, np_rng):
+ """Divide sample into a and b segments."""
+
+ # Number of sentences in the sample.
+ n_sentences = len(sample)
+ # Make sure we always have two sentences.
+ assert n_sentences > 1, 'make sure each sample has at least two sentences.'
+
+ # First part:
+ # `a_end` is how many sentences go into the `A`.
+ a_end = 1
+ if n_sentences >= 3:
+        # Note that randint in numpy is exclusive.
+ a_end = np_rng.randint(1, n_sentences)
+ tokens_a = []
+ for j in range(a_end):
+ tokens_a.extend(sample[j])
+
+ # Second part:
+ tokens_b = []
+ for j in range(a_end, n_sentences):
+ tokens_b.extend(sample[j])
+
+ # Random next:
+ is_next_random = False
+ if np_rng.random() < 0.5:
+ is_next_random = True
+ tokens_a, tokens_b = tokens_b, tokens_a
+
+ return tokens_a, tokens_b, is_next_random
+
+
+def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
+ """Truncates a pair of sequences to a maximum sequence length."""
+ #print(len_a, len_b, max_num_tokens)
+ assert len_a > 0
+ assert len_b > 0
+ if len_a + len_b <= max_num_tokens:
+ return False
+ while len_a + len_b > max_num_tokens:
+ if len_a > len_b:
+ len_a -= 1
+ tokens = tokens_a
+ else:
+ len_b -= 1
+ tokens = tokens_b
+ if np_rng.random() < 0.5:
+ del tokens[0]
+ else:
+ tokens.pop()
+ return True
+
+
+def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
+ """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
+
+ tokens = []
+ tokentypes = []
+ # [CLS].
+ tokens.append(cls_id)
+ tokentypes.append(0)
+ # Segment A.
+ for token in tokens_a:
+ tokens.append(token)
+ tokentypes.append(0)
+ # [SEP].
+ tokens.append(sep_id)
+ tokentypes.append(0)
+ # Segment B.
+ for token in tokens_b:
+ tokens.append(token)
+ tokentypes.append(1)
+ # [SEP].
+ tokens.append(sep_id)
+ tokentypes.append(1)
+
+ return tokens, tokentypes
+
+
+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+ ["index", "label"])
+
+
+def is_start_piece(piece):
+ """Check if the current word piece is the starting piece (BERT)."""
+ # When a word has been split into
+ # WordPieces, the first token does not have any marker and any subsequence
+ # tokens are prefixed with ##. So whenever we see the ## token, we
+ # append it to the previous set of word indexes.
+ return not piece.startswith("##")
+
+
+def create_masked_lm_predictions(tokens,
+ vocab_id_list, vocab_id_to_token_dict,
+ masked_lm_prob,
+ cls_id, sep_id, mask_id,
+ max_predictions_per_seq,
+ np_rng,
+ max_ngrams=3,
+ do_whole_word_mask=True,
+ favor_longer_ngram=False,
+ do_permutation=False):
+ """Creates the predictions for the masked LM objective.
+ Note: Tokens here are vocab ids and not text tokens."""
+
+ cand_indexes = []
+ # Note(mingdachen): We create a list for recording if the piece is
+ # the starting piece of current token, where 1 means true, so that
+ # on-the-fly whole word masking is possible.
+ token_boundary = [0] * len(tokens)
+
+ for (i, token) in enumerate(tokens):
+ if token == cls_id or token == sep_id:
+ token_boundary[i] = 1
+ continue
+        # Whole Word Masking means that we mask all of the wordpieces
+        # corresponding to an original word.
+ #
+ # Note that Whole Word Masking does *not* change the training code
+ # at all -- we still predict each WordPiece independently, softmaxed
+ # over the entire vocabulary.
+ if (do_whole_word_mask and len(cand_indexes) >= 1 and
+ not is_start_piece(vocab_id_to_token_dict[token])):
+ cand_indexes[-1].append(i)
+ else:
+ cand_indexes.append([i])
+ if is_start_piece(vocab_id_to_token_dict[token]):
+ token_boundary[i] = 1
+
+ output_tokens = list(tokens)
+
+ masked_lm_positions = []
+ masked_lm_labels = []
+
+ if masked_lm_prob == 0:
+ return (output_tokens, masked_lm_positions,
+ masked_lm_labels, token_boundary)
+
+ num_to_predict = min(max_predictions_per_seq,
+ max(1, int(round(len(tokens) * masked_lm_prob))))
+
+ # Note(mingdachen):
+    # By default, we set the probabilities to favor shorter ngram sequences.
+ ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
+ pvals = 1. / np.arange(1, max_ngrams + 1)
+ pvals /= pvals.sum(keepdims=True)
+
+ if favor_longer_ngram:
+ pvals = pvals[::-1]
+
+ ngram_indexes = []
+ for idx in range(len(cand_indexes)):
+ ngram_index = []
+ for n in ngrams:
+ ngram_index.append(cand_indexes[idx:idx + n])
+ ngram_indexes.append(ngram_index)
+
+ np_rng.shuffle(ngram_indexes)
+
+ masked_lms = []
+ covered_indexes = set()
+ for cand_index_set in ngram_indexes:
+ if len(masked_lms) >= num_to_predict:
+ break
+ if not cand_index_set:
+ continue
+ # Note(mingdachen):
+ # Skip current piece if they are covered in lm masking or previous ngrams.
+ for index_set in cand_index_set[0]:
+ for index in index_set:
+ if index in covered_indexes:
+ continue
+
+ n = np_rng.choice(ngrams[:len(cand_index_set)],
+ p=pvals[:len(cand_index_set)] /
+ pvals[:len(cand_index_set)].sum(keepdims=True))
+ index_set = sum(cand_index_set[n - 1], [])
+ n -= 1
+ # Note(mingdachen):
+ # Repeatedly looking for a candidate that does not exceed the
+ # maximum number of predictions by trying shorter ngrams.
+ while len(masked_lms) + len(index_set) > num_to_predict:
+ if n == 0:
+ break
+ index_set = sum(cand_index_set[n - 1], [])
+ n -= 1
+ # If adding a whole-word mask would exceed the maximum number of
+ # predictions, then just skip this candidate.
+ if len(masked_lms) + len(index_set) > num_to_predict:
+ continue
+ is_any_index_covered = False
+ for index in index_set:
+ if index in covered_indexes:
+ is_any_index_covered = True
+ break
+ if is_any_index_covered:
+ continue
+ for index in index_set:
+ covered_indexes.add(index)
+
+ masked_token = None
+ # 80% of the time, replace with [MASK]
+ if np_rng.random() < 0.8:
+ masked_token = mask_id
+ else:
+ # 10% of the time, keep original
+ if np_rng.random() < 0.5:
+ masked_token = tokens[index]
+ # 10% of the time, replace with random word
+ else:
+ masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
+
+ output_tokens[index] = masked_token
+
+ masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+ assert len(masked_lms) <= num_to_predict
+
+ np_rng.shuffle(ngram_indexes)
+
+ select_indexes = set()
+ if do_permutation:
+ for cand_index_set in ngram_indexes:
+ if len(select_indexes) >= num_to_predict:
+ break
+ if not cand_index_set:
+ continue
+ # Note(mingdachen):
+ # Skip the current piece if it is already covered by LM masking or previous ngrams.
+ for index_set in cand_index_set[0]:
+ for index in index_set:
+ if index in covered_indexes or index in select_indexes:
+ continue
+
+ n = np_rng.choice(ngrams[:len(cand_index_set)],
+ p=pvals[:len(cand_index_set)] /
+ pvals[:len(cand_index_set)].sum(keepdims=True))
+ index_set = sum(cand_index_set[n - 1], [])
+ n -= 1
+
+ while len(select_indexes) + len(index_set) > num_to_predict:
+ if n == 0:
+ break
+ index_set = sum(cand_index_set[n - 1], [])
+ n -= 1
+ # If adding a whole-word mask would exceed the maximum number of
+ # predictions, then just skip this candidate.
+ if len(select_indexes) + len(index_set) > num_to_predict:
+ continue
+ is_any_index_covered = False
+ for index in index_set:
+ if index in covered_indexes or index in select_indexes:
+ is_any_index_covered = True
+ break
+ if is_any_index_covered:
+ continue
+ for index in index_set:
+ select_indexes.add(index)
+ assert len(select_indexes) <= num_to_predict
+
+ select_indexes = sorted(select_indexes)
+ permute_indexes = list(select_indexes)
+ np_rng.shuffle(permute_indexes)
+ orig_token = list(output_tokens)
+
+ for src_i, tgt_i in zip(select_indexes, permute_indexes):
+ output_tokens[src_i] = orig_token[tgt_i]
+ masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
+
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
+
+ for p in masked_lms:
+ masked_lm_positions.append(p.index)
+ masked_lm_labels.append(p.label)
+
+ return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
+
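+# Illustrative arithmetic for the masking budget above: with 128 tokens,
+# masked_lm_prob=0.15 and max_predictions_per_seq=20,
+#     num_to_predict = min(20, max(1, round(128 * 0.15))) = min(20, 19) = 19.
+# Each selected position is then replaced with [MASK] ~80% of the time,
+# kept unchanged ~10% of the time, and swapped for a random vocab id ~10%
+# of the time, mirroring the original BERT recipe.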
+
+def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+ masked_labels, pad_id, max_seq_length):
+ """Pad sequences and convert them to numpy."""
+
+ # Some checks.
+ num_tokens = len(tokens)
+ padding_length = max_seq_length - num_tokens
+ assert padding_length >= 0
+ assert len(tokentypes) == num_tokens
+ assert len(masked_positions) == len(masked_labels)
+
+ # Tokens and token types.
+ filler = [pad_id] * padding_length
+ tokens_np = np.array(tokens + filler, dtype=np.int64)
+ tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
+
+ # Padding mask.
+ padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
+ dtype=np.int64)
+
+ # Labels and loss mask.
+ labels = [-1] * max_seq_length
+ loss_mask = [0] * max_seq_length
+ for i in range(len(masked_positions)):
+ assert masked_positions[i] < num_tokens
+ labels[masked_positions[i]] = masked_labels[i]
+ loss_mask[masked_positions[i]] = 1
+ labels_np = np.array(labels, dtype=np.int64)
+ loss_mask_np = np.array(loss_mask, dtype=np.int64)
+
+ return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
+
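+# Small worked example for the padding helper above (hypothetical ids,
+# pad_id=0, max_seq_length=6):
+#     tokens=[101, 7, 8, 102], tokentypes=[0, 0, 0, 0],
+#     masked_positions=[2], masked_labels=[9]
+# gives
+#     tokens_np       = [101, 7, 8, 102, 0, 0]
+#     tokentypes_np   = [0, 0, 0, 0, 0, 0]
+#     labels_np       = [-1, -1, 9, -1, -1, -1]
+#     padding_mask_np = [1, 1, 1, 1, 0, 0]
+#     loss_mask_np    = [0, 0, 1, 0, 0, 0]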
+
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+ train_valid_test_num_samples,
+ max_seq_length, masked_lm_prob,
+ short_seq_prob, seed, skip_warmup,
+ dataset_type='standard_bert'):
+
+ if dataset_type not in DSET_TYPES:
+ raise ValueError("Invalid dataset_type: ", dataset_type)
+
+ # Indexed dataset.
+ indexed_dataset = get_indexed_dataset_(data_prefix,
+ data_impl,
+ skip_warmup)
+
+ if dataset_type == DSET_TYPE_ICT:
+ args = get_args()
+ title_dataset = get_indexed_dataset_(args.titles_data_path,
+ data_impl,
+ skip_warmup)
+
+ # Get start and end indices of train/valid/test into doc-idx
+ # Note that doc-idx is designed to be num-docs + 1 so we can
+ # easily iterate over it.
+ total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
+ splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+ # Print stats about the splits.
+ print_rank_0(' > dataset split:')
+
+ def print_split_stats(name, index):
+ print_rank_0(' {}:'.format(name))
+ print_rank_0(' document indices in [{}, {}) total of {} '
+ 'documents'.format(splits[index], splits[index + 1],
+ splits[index + 1] - splits[index]))
+ start_index = indexed_dataset.doc_idx[splits[index]]
+ end_index = indexed_dataset.doc_idx[splits[index + 1]]
+ print_rank_0(' sentence indices in [{}, {}) total of {} '
+ 'sentences'.format(start_index, end_index,
+ end_index - start_index))
+ print_split_stats('train', 0)
+ print_split_stats('validation', 1)
+ print_split_stats('test', 2)
+
+ def build_dataset(index, name):
+ from megatron.data.bert_dataset import BertDataset
+ from megatron.data.ict_dataset import ICTDataset
+ dataset = None
+ if splits[index + 1] > splits[index]:
+ # Get the pointer to the original doc-idx so we can set it later.
+ doc_idx_ptr = indexed_dataset.get_doc_idx()
+ # Slice the doc-idx
+ start_index = splits[index]
+ # Add +1 so we can index into the dataset to get the upper bound.
+ end_index = splits[index + 1] + 1
+ # New doc_idx view.
+ indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
+ # Build the dataset accordingly.
+ kwargs = dict(
+ name=name,
+ data_prefix=data_prefix,
+ num_epochs=None,
+ max_num_samples=train_valid_test_num_samples[index],
+ max_seq_length=max_seq_length,
+ seed=seed
+ )
+
+ if dataset_type == DSET_TYPE_ICT:
+ args = get_args()
+ dataset = ICTDataset(
+ block_dataset=indexed_dataset,
+ title_dataset=title_dataset,
+ query_in_block_prob=args.query_in_block_prob,
+ use_one_sent_docs=args.use_one_sent_docs,
+ **kwargs
+ )
+ else:
+ dataset = BertDataset(
+ indexed_dataset=indexed_dataset,
+ masked_lm_prob=masked_lm_prob,
+ short_seq_prob=short_seq_prob,
+ **kwargs
+ )
+
+ # Set the original pointer so dataset remains the main dataset.
+ indexed_dataset.set_doc_idx(doc_idx_ptr)
+ # Checks.
+ assert indexed_dataset.doc_idx[0] == 0
+ assert indexed_dataset.doc_idx.shape[0] == \
+ (total_num_of_documents + 1)
+ return dataset
+
+ train_dataset = build_dataset(0, 'train')
+ valid_dataset = build_dataset(1, 'valid')
+ test_dataset = build_dataset(2, 'test')
+
+ return (train_dataset, valid_dataset, test_dataset)
+
+
+def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+
+ print_rank_0(' > building dataset index ...')
+
+ start_time = time.time()
+ indexed_dataset = make_indexed_dataset(data_prefix,
+ data_impl,
+ skip_warmup)
+ assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
+ print_rank_0(' > finished creating indexed dataset in {:4f} '
+ 'seconds'.format(time.time() - start_time))
+
+ print_rank_0(' > indexed dataset stats:')
+ print_rank_0(' number of documents: {}'.format(
+ indexed_dataset.doc_idx.shape[0] - 1))
+ print_rank_0(' number of sentences: {}'.format(
+ indexed_dataset.sizes.shape[0]))
+
+ return indexed_dataset
+
+
+def get_train_valid_test_split_(splits_string, size):
+ """ Get dataset splits from comma or '/' separated string list."""
+
+ splits = []
+ if splits_string.find(',') != -1:
+ splits = [float(s) for s in splits_string.split(',')]
+ elif splits_string.find('/') != -1:
+ splits = [float(s) for s in splits_string.split('/')]
+ else:
+ splits = [float(splits_string)]
+ while len(splits) < 3:
+ splits.append(0.)
+ splits = splits[:3]
+ splits_sum = sum(splits)
+ assert splits_sum > 0.0
+ splits = [split / splits_sum for split in splits]
+ splits_index = [0]
+ for index, split in enumerate(splits):
+ splits_index.append(splits_index[index] +
+ int(round(split * float(size))))
+ diff = splits_index[-1] - size
+ for index in range(1, len(splits_index)):
+ splits_index[index] -= diff
+ assert len(splits_index) == 4
+ assert splits_index[-1] == size
+ return splits_index
+
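+# Worked example for the splitter above:
+#     >>> get_train_valid_test_split_('949,50,1', 1000)
+#     [0, 949, 999, 1000]
+# i.e. documents [0, 949) go to train, [949, 999) to validation and
+# [999, 1000) to test.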
+
diff --git a/megatron_lm/megatron/data/gpt2_dataset.py b/megatron_lm/megatron/data/gpt2_dataset.py
new file mode 100644
index 0000000..f630a3c
--- /dev/null
+++ b/megatron_lm/megatron/data/gpt2_dataset.py
@@ -0,0 +1,317 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPT2 style dataset."""
+
+import os
+import time
+
+import numpy as np
+import torch
+
+from megatron import mpu, print_rank_0
+from megatron.data.dataset_utils import get_train_valid_test_split_
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+
+
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+ train_valid_test_num_samples,
+ seq_length, seed, skip_warmup):
+ """Build train, valid, and test datasets."""
+
+ # Indexed dataset.
+ indexed_dataset = get_indexed_dataset_(data_prefix,
+ data_impl,
+ skip_warmup)
+
+ total_num_of_documents = indexed_dataset.sizes.shape[0]
+ splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+ # Print stats about the splits.
+ print_rank_0(' > dataset split:')
+
+ def print_split_stats(name, index):
+ print_rank_0(' {}:'.format(name))
+ print_rank_0(' document indices in [{}, {}) total of {} '
+ 'documents'.format(splits[index], splits[index + 1],
+ splits[index + 1] - splits[index]))
+ print_split_stats('train', 0)
+ print_split_stats('validation', 1)
+ print_split_stats('test', 2)
+
+ def build_dataset(index, name):
+ dataset = None
+ if splits[index + 1] > splits[index]:
+ documents = np.arange(start=splits[index], stop=splits[index + 1],
+ step=1, dtype=np.int32)
+ dataset = GPT2Dataset(name, data_prefix,
+ documents, indexed_dataset,
+ train_valid_test_num_samples[index],
+ seq_length, seed)
+ return dataset
+
+ train_dataset = build_dataset(0, 'train')
+ valid_dataset = build_dataset(1, 'valid')
+ test_dataset = build_dataset(2, 'test')
+
+ return (train_dataset, valid_dataset, test_dataset)
+
+
+def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+ """Build indexed dataset."""
+ print_rank_0(' > building dataset index ...')
+
+ start_time = time.time()
+ indexed_dataset = make_indexed_dataset(data_prefix,
+ data_impl,
+ skip_warmup)
+ print_rank_0(' > finished creating indexed dataset in {:4f} '
+ 'seconds'.format(time.time() - start_time))
+ print_rank_0(' number of documents: {}'.format(
+ indexed_dataset.sizes.shape[0]))
+
+ return indexed_dataset
+
+
+class GPT2Dataset(torch.utils.data.Dataset):
+
+ def __init__(self, name, data_prefix, documents, indexed_dataset,
+ num_samples, seq_length, seed):
+
+ self.name = name
+ self.indexed_dataset = indexed_dataset
+
+ # Checks
+ assert np.min(documents) >= 0
+ assert np.max(documents) < indexed_dataset.sizes.shape[0]
+
+ # Build index mappings.
+ self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
+ self.name, data_prefix, documents, self.indexed_dataset.sizes,
+ num_samples, seq_length, seed)
+
+ def __len__(self):
+ # -1 is due to the data structure used to retrieve the index:
+ # sample i --> [sample_idx[i], sample_idx[i+1])
+ return self.sample_idx.shape[0] - 1
+
+ def __getitem__(self, idx):
+ # Get the shuffled index.
+ idx = self.shuffle_idx[idx]
+ # Start and end documents and offsets.
+ doc_index_f = self.sample_idx[idx][0]
+ doc_index_l = self.sample_idx[idx + 1][0]
+ offset_f = self.sample_idx[idx][1]
+ offset_l = self.sample_idx[idx + 1][1]
+ # If we are within the same document, just extract the chunk.
+ if doc_index_f == doc_index_l:
+ sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
+ offset=offset_f,
+ length=offset_l - offset_f + 1)
+ else:
+ # Otherwise, get the rest of the initial document.
+ sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
+ offset=offset_f)]
+ # Loop over all in between documents and add the entire document.
+ for i in range(doc_index_f + 1, doc_index_l):
+ sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
+ # And finally add the relevant portion of last document.
+ sample_list.append(self.indexed_dataset.get(
+ self.doc_idx[doc_index_l],
+ length=offset_l + 1))
+ sample = np.concatenate(sample_list)
+
+ return {'text': np.array(sample, dtype=np.int64)}
+
+
+def _build_index_mappings(name, data_prefix, documents, sizes,
+ num_samples, seq_length, seed):
+ """Build doc-idx, sample-idx, and shuffle-idx.
+ doc-idx: is an array (ordered) of documents to be used in training.
+ sample-idx: is the start document index and document offset for each
+ training sample.
+ shuffle-idx: maps the sample index into a random index into sample-idx.
+ """
+ # Number of tokens in each epoch and number of required epochs.
+ tokens_per_epoch = _num_tokens(documents, sizes)
+ num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
+ # rng state
+ np_rng = np.random.RandomState(seed=seed)
+
+ # Filename of the index mappings.
+ _filename = data_prefix
+ _filename += '_{}_indexmap'.format(name)
+ _filename += '_{}ns'.format(num_samples)
+ _filename += '_{}sl'.format(seq_length)
+ _filename += '_{}s'.format(seed)
+ doc_idx_filename = _filename + '_doc_idx.npy'
+ sample_idx_filename = _filename + '_sample_idx.npy'
+ shuffle_idx_filename = _filename + '_shuffle_idx.npy'
+
+ # Build the indexed mapping if not exist.
+ if torch.distributed.get_rank() == 0:
+ if (not os.path.isfile(doc_idx_filename)) or \
+ (not os.path.isfile(sample_idx_filename)) or \
+ (not os.path.isfile(shuffle_idx_filename)):
+
+ print_rank_0(' > WARNING: could not find index map files, building '
+ 'the indices on rank 0 ...')
+ # doc-idx.
+ start_time = time.time()
+ doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
+ np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+ print_rank_0(' > elapsed time to build and save doc-idx mapping '
+ '(seconds): {:4f}'.format(time.time() - start_time))
+ # sample-idx.
+ start_time = time.time()
+ # Use C++ implementation for speed.
+ # First compile and then import.
+ from megatron.data.dataset_utils import compile_helper
+ compile_helper()
+ from megatron.data import helpers
+ assert doc_idx.dtype == np.int32
+ assert sizes.dtype == np.int32
+ sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
+ num_epochs, tokens_per_epoch)
+ # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
+ # num_epochs, tokens_per_epoch)
+ np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+ print_rank_0(' > elapsed time to build and save sample-idx mapping '
+ '(seconds): {:4f}'.format(time.time() - start_time))
+ # shuffle-idx.
+ start_time = time.time()
+ # -1 is due to the data structure used to retrieve the index:
+ # sample i --> [sample_idx[i], sample_idx[i+1])
+ shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
+ np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+ print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
+ ' (seconds): {:4f}'.format(time.time() - start_time))
+
+ # This should be a barrier but nccl barrier assumes
+ # device_index=rank which is not the case for model
+ # parallel case
+ counts = torch.cuda.LongTensor([1])
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+ assert counts[0].item() == torch.distributed.get_world_size(
+ group=mpu.get_data_parallel_group())
+
+ # Load mappings.
+ start_time = time.time()
+ print_rank_0(' > loading doc-idx mapping from {}'.format(
+ doc_idx_filename))
+ doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
+ print_rank_0(' > loading sample-idx mapping from {}'.format(
+ sample_idx_filename))
+ sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
+ print_rank_0(' > loading shuffle-idx mapping from {}'.format(
+ shuffle_idx_filename))
+ shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
+ time.time() - start_time))
+ print_rank_0(' total number of samples: {}'.format(
+ sample_idx.shape[0]))
+ print_rank_0(' total number of epochs: {}'.format(num_epochs))
+
+ return doc_idx, sample_idx, shuffle_idx
+
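+# For illustration, with hypothetical arguments data_prefix='my-gpt2_text_document',
+# name='train', num_samples=1000, seq_length=1024 and seed=1234, the cached
+# mappings above are stored as
+#     my-gpt2_text_document_train_indexmap_1000ns_1024sl_1234s_doc_idx.npy
+#     my-gpt2_text_document_train_indexmap_1000ns_1024sl_1234s_sample_idx.npy
+#     my-gpt2_text_document_train_indexmap_1000ns_1024sl_1234s_shuffle_idx.npy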
+
+def _num_tokens(documents, sizes):
+ """Total number of tokens in the dataset."""
+ return np.sum(sizes[documents])
+
+
+def _num_epochs(tokens_per_epoch, seq_length, num_samples):
+ """Based on number of samples and sequence lenght, calculate how many
+ epochs will be needed."""
+ num_epochs = 0
+ total_tokens = 0
+ while True:
+ num_epochs += 1
+ total_tokens += tokens_per_epoch
+ # -1 is because we need to retrieve seq_length + 1 tokens each time
+ # but the last token will overlap with the first token of the next
+ # sample except for the last sample.
+ if ((total_tokens - 1) // seq_length) >= num_samples:
+ return num_epochs
+
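+# Example: with tokens_per_epoch=1000, seq_length=100 and num_samples=25,
+#     epoch 1: (1000 - 1) // 100 = 9  < 25
+#     epoch 2: (2000 - 1) // 100 = 19 < 25
+#     epoch 3: (3000 - 1) // 100 = 29 >= 25
+# so _num_epochs(1000, 100, 25) returns 3.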
+
+def _build_doc_idx(documents, num_epochs, np_rng):
+ """Build an array with length = number-of-epochs * number-of-dcuments.
+ Each index is mapped to a corresponding document."""
+ doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
+ doc_idx[:] = documents
+ doc_idx = doc_idx.reshape(-1)
+ doc_idx = doc_idx.astype(np.int32)
+ np_rng.shuffle(doc_idx)
+ return doc_idx
+
+
+def _build_sample_idx(sizes, doc_idx, seq_length,
+ num_epochs, tokens_per_epoch):
+ """Sample index mapping is a 2D array with sizes
+ [number-of-samples + 1, 2] where [..., 0] contains
+ the index into `doc_idx` and [..., 1] is the
+ starting offset in that document."""
+
+ # Total number of samples. For -1 see comments in `_num_epochs`.
+ num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
+ sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
+
+ # Index into sample_idx.
+ sample_index = 0
+ # Index into doc_idx.
+ doc_idx_index = 0
+ # Beginning offset for each document.
+ doc_offset = 0
+ # Start with first document and no offset.
+ sample_idx[sample_index][0] = doc_idx_index
+ sample_idx[sample_index][1] = doc_offset
+ sample_index += 1
+ while sample_index <= num_samples:
+ # Start with a fresh sequence.
+ remaining_seq_length = seq_length + 1
+ while remaining_seq_length != 0:
+ # Get the document length.
+ doc_id = doc_idx[doc_idx_index]
+ doc_length = sizes[doc_id] - doc_offset
+ # And add it to the current sequence.
+ remaining_seq_length -= doc_length
+ # If we have more than a full sequence, adjust offset and set
+ # remaining length to zero so we return from the while loop.
+ # Note that -1 here is for the same reason we have -1 in
+ # `_num_epochs` calculations.
+ if remaining_seq_length <= 0:
+ doc_offset += (remaining_seq_length + doc_length - 1)
+ remaining_seq_length = 0
+ else:
+ # Otherwise, start from the beginning of the next document.
+ doc_idx_index += 1
+ doc_offset = 0
+ # Record the sequence.
+ sample_idx[sample_index][0] = doc_idx_index
+ sample_idx[sample_index][1] = doc_offset
+ sample_index += 1
+
+ return sample_idx
+
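+# Tiny trace of the pure-Python reference above (no shuffling): with
+# sizes=[7, 3], doc_idx=[0, 1], seq_length=4, num_epochs=1 and
+# tokens_per_epoch=10, num_samples = (10 - 1) // 4 = 2 and
+#     sample_idx = [[0, 0], [0, 4], [1, 1]]
+# so sample 0 covers tokens 0..4 of document 0, and sample 1 starts at
+# token 4 of document 0 and runs through token 1 of document 1 (each
+# sample spans seq_length + 1 tokens, sharing one boundary token).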
+
+def _build_shuffle_idx(size, np_rng):
+ """Build the range [0, size) and shuffle."""
+ dtype_ = np.uint32
+ if size >= (np.iinfo(np.uint32).max - 1):
+ dtype_ = np.int64
+ shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
+ np_rng.shuffle(shuffle_idx)
+ return shuffle_idx
diff --git a/megatron_lm/megatron/data/helpers.cpp b/megatron_lm/megatron/data/helpers.cpp
new file mode 100644
index 0000000..ca90329
--- /dev/null
+++ b/megatron_lm/megatron/data/helpers.cpp
@@ -0,0 +1,643 @@
+/*
+ coding=utf-8
+ Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+
+/* Helper methods for fast index mapping builds */
+
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <math.h>
+#include <stdexcept>
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <random>
+
+namespace py = pybind11;
+using namespace std;
+
+const int32_t LONG_SENTENCE_LEN = 512;
+
+
+py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
+ const py::array_t<int32_t>& doc_idx_,
+ const int32_t seq_length,
+ const int32_t num_epochs,
+ const int64_t tokens_per_epoch) {
+ /* Sample index (sample_idx) is used for GPT2-like datasets for which
+ the documents are flattened and the samples are built based on this
+ 1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
+ where [..., 0] contains the index into `doc_idx` and [..., 1] is the
+ starting offset in that document. */
+
+ // Consistency checks.
+ assert(seq_length > 1);
+ assert(num_epochs > 0);
+ assert(tokens_per_epoch > 1);
+
+ // Remove bound checks.
+ auto sizes = sizes_.unchecked<1>();
+ auto doc_idx = doc_idx_.unchecked<1>();
+
+ // Mapping and its length (1D).
+ int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
+ int32_t* sample_idx = new int32_t[2*(num_samples+1)];
+
+ cout << " using:" << endl << std::flush;
+ cout << " number of documents: " <<
+ doc_idx_.shape(0) / num_epochs << endl << std::flush;
+ cout << " number of epochs: " << num_epochs <<
+ endl << std::flush;
+ cout << " sequence length: " << seq_length <<
+ endl << std::flush;
+ cout << " total number of samples: " << num_samples <<
+ endl << std::flush;
+
+ // Index into sample_idx.
+ int64_t sample_index = 0;
+ // Index into doc_idx.
+ int64_t doc_idx_index = 0;
+ // Beginning offset for each document.
+ int32_t doc_offset = 0;
+ // Start with first document and no offset.
+ sample_idx[2 * sample_index] = doc_idx_index;
+ sample_idx[2 * sample_index + 1] = doc_offset;
+ ++sample_index;
+
+ while (sample_index <= num_samples) {
+ // Start with a fresh sequence.
+ int32_t remaining_seq_length = seq_length + 1;
+ while (remaining_seq_length != 0) {
+ // Get the document length.
+ auto doc_id = doc_idx[doc_idx_index];
+ auto doc_length = sizes[doc_id] - doc_offset;
+ // And add it to the current sequence.
+ remaining_seq_length -= doc_length;
+ // If we have more than a full sequence, adjust offset and set
+ // remaining length to zero so we return from the while loop.
+ // Note that -1 here is for the same reason we have -1 in
+ // `_num_epochs` calculations.
+ if (remaining_seq_length <= 0) {
+ doc_offset += (remaining_seq_length + doc_length - 1);
+ remaining_seq_length = 0;
+ } else {
+ // Otherwise, start from the beginning of the next document.
+ ++doc_idx_index;
+ doc_offset = 0;
+ }
+ }
+ // Record the sequence.
+ sample_idx[2 * sample_index] = doc_idx_index;
+ sample_idx[2 * sample_index + 1] = doc_offset;
+ ++sample_index;
+ }
+
+ // Method to deallocate memory.
+ py::capsule free_when_done(sample_idx, [](void *mem_) {
+ int32_t *mem = reinterpret_cast<int32_t*>(mem_);
+ delete[] mem;
+ });
+
+ // Return the numpy array.
+ const auto byte_size = sizeof(int32_t);
+ return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
+ {2*byte_size, byte_size}, // C-style contiguous strides
+ sample_idx, // the data pointer
+ free_when_done); // numpy array references
+
+}
+
+
+inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
+ const int32_t max_length,
+ std::mt19937& rand32_gen) {
+ /* Training sample length. */
+ const auto random_number = rand32_gen();
+ if ((random_number % short_seq_ratio) == 0) {
+ return 2 + random_number % (max_length - 1);
+ }
+ return max_length;
+}
+
+
+template<typename DocIdx>
+py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
+ const py::array_t<int32_t>& sizes_,
+ const int32_t num_epochs,
+ const uint64_t max_num_samples,
+ const int32_t max_seq_length,
+ const double short_seq_prob,
+ const int32_t seed,
+ const bool verbose) {
+ /* Build a mapping of (start-index, end-index, sequence-length) where
+ start and end index are the indices of the sentences in the sample
+ and sequence-length is the target sequence length.
+ */
+
+ // Consistency checks.
+ assert(num_epochs > 0);
+ assert(max_seq_length > 1);
+ assert(short_seq_prob > 0.0);
+ assert(short_seq_prob <= 1.0);
+ assert(seed > 0);
+
+ // Remove bound checks.
+ auto docs = docs_.unchecked<1>();
+ auto sizes = sizes_.unchecked<1>();
+
+ // For efficiency, convert probability to ratio. Note: rand() generates int.
+ const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+
+ if (verbose) {
+ const auto sent_start_index = docs[0];
+ const auto sent_end_index = docs[docs_.shape(0) - 1];
+ const auto num_sentences = sent_end_index - sent_start_index;
+ cout << " using:" << endl << std::flush;
+ cout << " number of documents: " << docs_.shape(0) - 1 <<
+ endl << std::flush;
+ cout << " sentences range: [" << sent_start_index <<
+ ", " << sent_end_index << ")" << endl << std::flush;
+ cout << " total number of sentences: " << num_sentences <<
+ endl << std::flush;
+ cout << " number of epochs: " << num_epochs <<
+ endl << std::flush;
+ cout << " maximum number of samples: " << max_num_samples <<
+ endl << std::flush;
+ cout << " maximum sequence length: " << max_seq_length <<
+ endl << std::flush;
+ cout << " short sequence probability: " << short_seq_prob <<
+ endl << std::flush;
+ cout << " short sequence ration (1/prob): " << short_seq_ratio <<
+ endl << std::flush;
+ cout << " seed: " << seed << endl <<
+ std::flush;
+ }
+
+ // Mapping and its length (1D).
+ int64_t num_samples = -1;
+ DocIdx* maps = NULL;
+
+ // Perform two iterations, in the first iteration get the size
+ // and allocate memory and in the second iteration populate the map.
+ bool second = false;
+ for (int32_t iteration=0; iteration<2; ++iteration) {
+
+ // Set the seed so both iterations produce the same results.
+ std::mt19937 rand32_gen(seed);
+
+ // Set the flag on second iteration.
+ second = (iteration == 1);
+
+ // Counters:
+ uint64_t empty_docs = 0;
+ uint64_t one_sent_docs = 0;
+ uint64_t long_sent_docs = 0;
+
+ // Current map index.
+ uint64_t map_index = 0;
+
+ // For each epoch:
+ for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
+ if (map_index >= max_num_samples) {
+ if (verbose && (!second)) {
+ cout << " reached " << max_num_samples << " samples after "
+ << epoch << " epochs ..." << endl << std::flush;
+ }
+ break;
+ }
+ // For each document:
+ for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
+
+ // Document sentences are in [sent_index_first, sent_index_last)
+ const auto sent_index_first = docs[doc];
+ const auto sent_index_last = docs[doc + 1];
+
+ // At the beginning of the document the previous index is the
+ // start index.
+ auto prev_start_index = sent_index_first;
+
+ // Remaining sentences.
+ auto num_remain_sent = sent_index_last - sent_index_first;
+
+ // Some bookkeeping
+ if ((epoch == 0) && (!second)) {
+ if (num_remain_sent == 0) {
+ ++empty_docs;
+ }
+ if (num_remain_sent == 1) {
+ ++one_sent_docs;
+ }
+ }
+
+ // Detect documents with long sentences.
+ bool contains_long_sentence = false;
+ if (num_remain_sent > 1) {
+ for (auto sent_index=sent_index_first;
+ sent_index < sent_index_last; ++sent_index) {
+ if (sizes[sent_index] > LONG_SENTENCE_LEN){
+ if ((epoch == 0) && (!second)) {
+ ++long_sent_docs;
+ }
+ contains_long_sentence = true;
+ break;
+ }
+ }
+ }
+
+ // If we have at least two sentences.
+ if ((num_remain_sent > 1) && (!contains_long_sentence)) {
+
+ // Set values.
+ auto seq_len = int32_t{0};
+ auto num_sent = int32_t{0};
+ auto target_seq_len = get_target_sample_len(short_seq_ratio,
+ max_seq_length,
+ rand32_gen);
+
+ // Loop through sentences.
+ for (auto sent_index=sent_index_first;
+ sent_index < sent_index_last; ++sent_index) {
+
+ // Add the size and number of sentences.
+ seq_len += sizes[sent_index];
+ ++num_sent;
+ --num_remain_sent;
+
+ // If we have reached the target length,
+ // and more than one sentence is left in the document,
+ // and we have at least two sentences,
+ // or if we have reached the end of the document.
+ if (((seq_len >= target_seq_len) &&
+ (num_remain_sent > 1) &&
+ (num_sent > 1) ) || (num_remain_sent == 0)) {
+
+ // Check for overflow.
+ if ((3 * map_index + 2) >
+ std::numeric_limits<int64_t>::max()) {
+ cout << "number of samples exceeded maximum "
+ << "allowed by type int64: "
+ << std::numeric_limits<int64_t>::max()
+ << endl;
+ throw std::overflow_error("Number of samples");
+ }
+
+ // Populate the map.
+ if (second) {
+ const auto map_index_0 = 3 * map_index;
+ maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
+ maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
+ maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
+ }
+
+ // Update indices / counters.
+ ++map_index;
+ prev_start_index = sent_index + 1;
+ target_seq_len = get_target_sample_len(short_seq_ratio,
+ max_seq_length,
+ rand32_gen);
+ seq_len = 0;
+ num_sent = 0;
+ }
+
+ } // for (auto sent_index=sent_index_first; ...
+ } // if (num_remain_sent > 1) {
+ } // for (int doc=0; doc < num_docs; ++doc) {
+ } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+
+ if (!second) {
+ if (verbose) {
+ cout << " number of empty documents: " << empty_docs <<
+ endl << std::flush;
+ cout << " number of documents with one sentence: " <<
+ one_sent_docs << endl << std::flush;
+ cout << " number of documents with long sentences: " <<
+ long_sent_docs << endl << std::flush;
+ cout << " will create mapping for " << map_index <<
+ " samples" << endl << std::flush;
+ }
+ assert(maps == NULL);
+ assert(num_samples < 0);
+ maps = new DocIdx[3*map_index];
+ num_samples = static_cast<int64_t>(map_index);
+ }
+
+ } // for (int iteration=0; iteration < 2; ++iteration) {
+
+ // Shuffle.
+ // We need a 64 bit random number generator as we might have more
+ // than 2 billion samples.
+ std::mt19937_64 rand64_gen(seed + 1);
+ for (auto i=(num_samples - 1); i > 0; --i) {
+ const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
+ const auto i0 = 3 * i;
+ const auto j0 = 3 * j;
+ // Swap values.
+ swap(maps[i0], maps[j0]);
+ swap(maps[i0 + 1], maps[j0 + 1]);
+ swap(maps[i0 + 2], maps[j0 + 2]);
+ }
+
+ // Method to deallocate memory.
+ py::capsule free_when_done(maps, [](void *mem_) {
+ DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
+ delete[] mem;
+ });
+
+ // Return the numpy array.
+ const auto byte_size = sizeof(DocIdx);
+ return py::array(std::vector<int64_t>{num_samples, 3}, // shape
+ {3*byte_size, byte_size}, // C-style contiguous strides
+ maps, // the data pointer
+ free_when_done); // numpy array references
+
+}
+
+
+py::array build_mapping(const py::array_t<int64_t>& docs_,
+ const py::array_t<int>& sizes_,
+ const int num_epochs,
+ const uint64_t max_num_samples,
+ const int max_seq_length,
+ const double short_seq_prob,
+ const int seed,
+ const bool verbose) {
+
+ if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
+ if (verbose) {
+ cout << " using uint64 for data mapping..." << endl << std::flush;
+ }
+ return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
+ max_num_samples, max_seq_length,
+ short_seq_prob, seed, verbose);
+ } else {
+ if (verbose) {
+ cout << " using uint32 for data mapping..." << endl << std::flush;
+ }
+ return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
+ max_num_samples, max_seq_length,
+ short_seq_prob, seed, verbose);
+ }
+}
+
+template<typename DocIdx>
+py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
+ const py::array_t<int32_t>& sizes_,
+ const py::array_t<int32_t>& titles_sizes_,
+ const int32_t num_epochs,
+ const uint64_t max_num_samples,
+ const int32_t max_seq_length,
+ const int32_t seed,
+ const bool verbose,
+ const bool use_one_sent_blocks) {
+ /* Build a mapping of (start-index, end-index, doc-index, block-id) where
+ start and end index are the indices of the sentences in the block,
+ doc-index is the document the block comes from (used to fetch its title),
+ and block-id is a unique id for the block.
+ */
+
+ // Consistency checks.
+ assert(num_epochs > 0);
+ assert(max_seq_length > 1);
+ assert(seed > 0);
+
+ // Remove bound checks.
+ auto docs = docs_.unchecked<1>();
+ auto sizes = sizes_.unchecked<1>();
+ auto titles_sizes = titles_sizes_.unchecked<1>();
+
+ if (verbose) {
+ const auto sent_start_index = docs[0];
+ const auto sent_end_index = docs[docs_.shape(0) - 1];
+ const auto num_sentences = sent_end_index - sent_start_index;
+ cout << " using:" << endl << std::flush;
+ cout << " number of documents: " << docs_.shape(0) - 1 <<
+ endl << std::flush;
+ cout << " sentences range: [" << sent_start_index <<
+ ", " << sent_end_index << ")" << endl << std::flush;
+ cout << " total number of sentences: " << num_sentences <<
+ endl << std::flush;
+ cout << " number of epochs: " << num_epochs <<
+ endl << std::flush;
+ cout << " maximum number of samples: " << max_num_samples <<
+ endl << std::flush;
+ cout << " maximum sequence length: " << max_seq_length <<
+ endl << std::flush;
+ cout << " seed: " << seed << endl <<
+ std::flush;
+ }
+
+ // Mapping and its length (1D).
+ int64_t num_samples = -1;
+ DocIdx* maps = NULL;
+
+ // Acceptable number of sentences per block.
+ int min_num_sent = 2;
+ if (use_one_sent_blocks) {
+ min_num_sent = 1;
+ }
+
+ // Perform two iterations, in the first iteration get the size
+ // and allocate memory and in the second iteration populate the map.
+ bool second = false;
+ for (int32_t iteration=0; iteration<2; ++iteration) {
+
+ // Set the flag on second iteration.
+ second = (iteration == 1);
+
+ // Current map index.
+ uint64_t map_index = 0;
+
+ uint64_t empty_docs = 0;
+ uint64_t one_sent_docs = 0;
+ uint64_t long_sent_docs = 0;
+ // For each epoch:
+ for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
+ // assign every block a unique id
+ int32_t block_id = 0;
+
+ if (map_index >= max_num_samples) {
+ if (verbose && (!second)) {
+ cout << " reached " << max_num_samples << " samples after "
+ << epoch << " epochs ..." << endl << std::flush;
+ }
+ break;
+ }
+ // For each document:
+ for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
+
+ // Document sentences are in [sent_index_first, sent_index_last)
+ const auto sent_index_first = docs[doc];
+ const auto sent_index_last = docs[doc + 1];
+ const auto target_seq_len = max_seq_length - titles_sizes[doc];
+
+ // At the beginning of the document the previous index is the
+ // start index.
+ auto prev_start_index = sent_index_first;
+
+ // Remaining sentences.
+ auto num_remain_sent = sent_index_last - sent_index_first;
+
+ // Some bookkeeping
+ if ((epoch == 0) && (!second)) {
+ if (num_remain_sent == 0) {
+ ++empty_docs;
+ }
+ if (num_remain_sent == 1) {
+ ++one_sent_docs;
+ }
+ }
+ // Detect documents with long sentences.
+ bool contains_long_sentence = false;
+ if (num_remain_sent >= min_num_sent) {
+ for (auto sent_index=sent_index_first;
+ sent_index < sent_index_last; ++sent_index) {
+ if (sizes[sent_index] > LONG_SENTENCE_LEN){
+ if ((epoch == 0) && (!second)) {
+ ++long_sent_docs;
+ }
+ contains_long_sentence = true;
+ break;
+ }
+ }
+ }
+ // If we have enough sentences and no long sentences.
+ if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
+
+ // Set values.
+ auto seq_len = int32_t{0};
+ auto num_sent = int32_t{0};
+
+ // Loop through sentences.
+ for (auto sent_index=sent_index_first;
+ sent_index < sent_index_last; ++sent_index) {
+
+ // Add the size and number of sentences.
+ seq_len += sizes[sent_index];
+ ++num_sent;
+ --num_remain_sent;
+
+ // If we have reached the target length,
+ // and an acceptable number of sentences is left,
+ // and we have at least the minimum number of sentences,
+ // or if we have reached the end of the document.
+ if (((seq_len >= target_seq_len) &&
+ (num_remain_sent >= min_num_sent) &&
+ (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
+
+ // Populate the map.
+ if (second) {
+ const auto map_index_0 = 4 * map_index;
+ // Each sample has 4 items: the starting sentence index, ending sentence index,
+ // the index of the document from which the block comes (used for fetching titles)
+ // and the unique id of the block (used for creating block indexes)
+
+ maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
+ maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
+ maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
+ maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
+ }
+
+ // Update indices / counters.
+ ++map_index;
+ ++block_id;
+ prev_start_index = sent_index + 1;
+ seq_len = 0;
+ num_sent = 0;
+ }
+ } // for (auto sent_index=sent_index_first; ...
+ } // if (num_remain_sent >= min_num_sent) {
+ } // for (int doc=0; doc < num_docs; ++doc) {
+ } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+
+ if (!second) {
+ if (verbose) {
+ cout << " number of empty documents: " << empty_docs <<
+ endl << std::flush;
+ cout << " number of documents with one sentence: " <<
+ one_sent_docs << endl << std::flush;
+ cout << " number of documents with long sentences: " <<
+ long_sent_docs << endl << std::flush;
+ cout << " will create mapping for " << map_index <<
+ " samples" << endl << std::flush;
+ }
+ assert(maps == NULL);
+ assert(num_samples < 0);
+ maps = new DocIdx[4*map_index];
+ num_samples = static_cast<int64_t>(map_index);
+ }
+
+ } // for (int iteration=0; iteration < 2; ++iteration) {
+
+ // Shuffle.
+ // We need a 64 bit random number generator as we might have more
+ // than 2 billion samples.
+ std::mt19937_64 rand64_gen(seed + 1);
+ for (auto i=(num_samples - 1); i > 0; --i) {
+ const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
+ const auto i0 = 4 * i;
+ const auto j0 = 4 * j;
+ // Swap values.
+ swap(maps[i0], maps[j0]);
+ swap(maps[i0 + 1], maps[j0 + 1]);
+ swap(maps[i0 + 2], maps[j0 + 2]);
+ swap(maps[i0 + 3], maps[j0 + 3]);
+ }
+
+ // Method to deallocate memory.
+ py::capsule free_when_done(maps, [](void *mem_) {
+ DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
+ delete[] mem;
+ });
+
+ // Return the numpy array.
+ const auto byte_size = sizeof(DocIdx);
+ return py::array(std::vector<int64_t>{num_samples, 4}, // shape
+ {4*byte_size, byte_size}, // C-style contiguous strides
+ maps, // the data pointer
+ free_when_done); // numpy array references
+
+}
+
+py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
+ const py::array_t<int>& sizes_,
+ const py::array_t<int>& titles_sizes_,
+ const int num_epochs,
+ const uint64_t max_num_samples,
+ const int max_seq_length,
+ const int seed,
+ const bool verbose,
+ const bool use_one_sent_blocks) {
+
+ if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
+ if (verbose) {
+ cout << " using uint64 for data mapping..." << endl << std::flush;
+ }
+ return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
+ num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
+ } else {
+ if (verbose) {
+ cout << " using uint32 for data mapping..." << endl << std::flush;
+ }
+ return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
+ num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
+ }
+}
+
+PYBIND11_MODULE(helpers, m) {
+ m.def("build_mapping", &build_mapping);
+ m.def("build_blocks_mapping", &build_blocks_mapping);
+ m.def("build_sample_idx", &build_sample_idx);
+}
diff --git a/megatron_lm/megatron/data/ict_dataset.py b/megatron_lm/megatron/data/ict_dataset.py
new file mode 100644
index 0000000..71916d6
--- /dev/null
+++ b/megatron_lm/megatron/data/ict_dataset.py
@@ -0,0 +1,140 @@
+import itertools
+import random
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from megatron import get_tokenizer
+from megatron import get_args
+from megatron.data.dataset_utils import get_indexed_dataset_
+from megatron.data.realm_dataset_utils import get_block_samples_mapping
+
+
+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
+ """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
+ rather than for training, since it is only built with a single epoch sample mapping.
+ """
+ args = get_args()
+ block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
+ titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
+
+ kwargs = dict(
+ name='full',
+ block_dataset=block_dataset,
+ title_dataset=titles_dataset,
+ data_prefix=args.data_path,
+ num_epochs=1,
+ max_num_samples=None,
+ max_seq_length=args.seq_length,
+ seed=1,
+ query_in_block_prob=query_in_block_prob,
+ use_titles=use_titles,
+ use_one_sent_docs=args.use_one_sent_docs
+ )
+ dataset = ICTDataset(**kwargs)
+ return dataset
+
+
+class ICTDataset(Dataset):
+ """Dataset containing sentences and their blocks for an inverse cloze task."""
+ def __init__(self, name, block_dataset, title_dataset, data_prefix,
+ num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
+ seed, use_titles=True, use_one_sent_docs=False):
+ self.name = name
+ self.seed = seed
+ self.max_seq_length = max_seq_length
+ self.query_in_block_prob = query_in_block_prob
+ self.block_dataset = block_dataset
+ self.title_dataset = title_dataset
+ self.rng = random.Random(self.seed)
+ self.use_titles = use_titles
+ self.use_one_sent_docs = use_one_sent_docs
+
+ self.samples_mapping = get_block_samples_mapping(
+ block_dataset, title_dataset, data_prefix, num_epochs,
+ max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
+ self.tokenizer = get_tokenizer()
+ self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
+ self.vocab_id_to_token_list = self.tokenizer.inv_vocab
+ self.cls_id = self.tokenizer.cls
+ self.sep_id = self.tokenizer.sep
+ self.mask_id = self.tokenizer.mask
+ self.pad_id = self.tokenizer.pad
+
+ def __len__(self):
+ return len(self.samples_mapping)
+
+ def __getitem__(self, idx):
+ """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
+ sample_data = self.samples_mapping[idx]
+ start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
+
+ if self.use_titles:
+ title = self.title_dataset[int(doc_idx)]
+ title_pad_offset = 3 + len(title)
+ else:
+ title = None
+ title_pad_offset = 2
+ block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
+ assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
+
+ # randint() is inclusive for Python rng
+ rand_sent_idx = self.rng.randint(0, len(block) - 1)
+
+ # keep the query in the context query_in_block_prob fraction of the time.
+ if self.rng.random() < self.query_in_block_prob:
+ query = block[rand_sent_idx].copy()
+ else:
+ query = block.pop(rand_sent_idx)
+
+ # We still need to truncate because blocks are only concluded once the
+ # accumulated sentence lengths exceed max_seq_length.
+ query = query[:self.max_seq_length - 2]
+ block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
+
+ query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
+ block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+ block_data = sample_data.as_array()
+
+ sample = {
+ 'query_tokens': query_tokens,
+ 'query_pad_mask': query_pad_mask,
+ 'block_tokens': block_tokens,
+ 'block_pad_mask': block_pad_mask,
+ 'block_data': block_data,
+ }
+
+ return sample
+
+ def get_block(self, start_idx, end_idx, doc_idx):
+ """Get the IDs for an evidence block plus the title of the corresponding document"""
+ block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
+ title = self.title_dataset[int(doc_idx)]
+
+ block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
+ block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+
+ return block_tokens, block_pad_mask
+
+ def get_null_block(self):
+ """Get empty block and title - used in REALM pretraining"""
+ block, title = [], []
+ block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+
+ return block_tokens, block_pad_mask
+
+ def concat_and_pad_tokens(self, tokens, title=None):
+ """Concat with special tokens and pad sequence to self.max_seq_length"""
+ tokens = list(tokens)
+ if title is None:
+ tokens = [self.cls_id] + tokens + [self.sep_id]
+ else:
+ title = list(title)
+ tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
+ assert len(tokens) <= self.max_seq_length
+
+ num_pad = self.max_seq_length - len(tokens)
+ pad_mask = [1] * len(tokens) + [0] * num_pad
+ tokens += [self.pad_id] * num_pad
+
+ return np.array(tokens), np.array(pad_mask)
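+
+# Worked example of concat_and_pad_tokens (hypothetical ids: cls_id=101,
+# sep_id=102, pad_id=0, max_seq_length=8): tokens=[5, 6] with title=[9]
+# becomes
+#     tokens   = [101, 9, 102, 5, 6, 102, 0, 0]
+#     pad_mask = [1,   1, 1,   1, 1, 1,   0, 0]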
diff --git a/megatron_lm/megatron/data/indexed_dataset.py b/megatron_lm/megatron/data/indexed_dataset.py
new file mode 100644
index 0000000..1251066
--- /dev/null
+++ b/megatron_lm/megatron/data/indexed_dataset.py
@@ -0,0 +1,570 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# copied from fairseq/fairseq/data/indexed_dataset.py
+# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
+# other slight modifications to remove fairseq dependencies
+# Added document index to index file and made it accessible.
+# An empty sentence no longer separates documents.
+
+from functools import lru_cache
+import os
+import shutil
+import struct
+from itertools import accumulate
+
+import numpy as np
+import torch
+from megatron import print_rank_0
+
+
+def __best_fitting_dtype(vocab_size=None):
+ if vocab_size is not None and vocab_size < 65500:
+ return np.uint16
+ else:
+ return np.int32
+
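+# For instance, a GPT-2 style vocabulary of 50257 tokens fits in np.uint16
+# (50257 < 65500), while anything at or above 65500 falls back to np.int32.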
+
+def get_available_dataset_impl():
+ return ['lazy', 'cached', 'mmap']
+
+
+def infer_dataset_impl(path):
+ if IndexedDataset.exists(path):
+ with open(index_file_path(path), 'rb') as f:
+ magic = f.read(8)
+ if magic == IndexedDataset._HDR_MAGIC:
+ return 'cached'
+ elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
+ return 'mmap'
+ else:
+ return None
+ else:
+ print(f"Dataset does not exist: {path}")
+ print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
+ return None
+
+
+def make_builder(out_file, impl, vocab_size=None):
+ if impl == 'mmap':
+ return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
+ else:
+ return IndexedDatasetBuilder(out_file)
+
+
+def make_dataset(path, impl, skip_warmup=False):
+ if not IndexedDataset.exists(path):
+ print(f"Dataset does not exist: {path}")
+ print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
+ return None
+ if impl == 'infer':
+ impl = infer_dataset_impl(path)
+ if impl == 'lazy' and IndexedDataset.exists(path):
+ return IndexedDataset(path)
+ elif impl == 'cached' and IndexedDataset.exists(path):
+ return IndexedCachedDataset(path)
+ elif impl == 'mmap' and MMapIndexedDataset.exists(path):
+ return MMapIndexedDataset(path, skip_warmup)
+ print(f"Unknown dataset implementation: {impl}")
+ return None
+
+
+def dataset_exists(path, impl):
+ if impl == 'mmap':
+ return MMapIndexedDataset.exists(path)
+ else:
+ return IndexedDataset.exists(path)
+
+
+def read_longs(f, n):
+ a = np.empty(n, dtype=np.int64)
+ f.readinto(a)
+ return a
+
+
+def write_longs(f, a):
+ f.write(np.array(a, dtype=np.int64))
+
+
+dtypes = {
+ 1: np.uint8,
+ 2: np.int8,
+ 3: np.int16,
+ 4: np.int32,
+ 5: np.int64,
+ 6: np.float,
+ 7: np.double,
+ 8: np.uint16
+}
+
+
+def code(dtype):
+ for k in dtypes.keys():
+ if dtypes[k] == dtype:
+ return k
+ raise ValueError(dtype)
+
+
+def index_file_path(prefix_path):
+ return prefix_path + '.idx'
+
+
+def data_file_path(prefix_path):
+ return prefix_path + '.bin'
+
+
+def create_doc_idx(sizes):
+ doc_idx = [0]
+ for i, s in enumerate(sizes):
+ if s == 0:
+ doc_idx.append(i + 1)
+ return doc_idx
+
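+# Example: create_doc_idx([3, 4, 0, 2, 0]) == [0, 3, 5]; zero-length entries
+# in `sizes` are treated as document boundaries by this helper, so doc_idx
+# records where each document's sentences start.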
+
+class IndexedDataset(torch.utils.data.Dataset):
+ """Loader for IndexedDataset"""
+ _HDR_MAGIC = b'TNTIDX\x00\x00'
+
+ def __init__(self, path):
+ super().__init__()
+ self.path = path
+ self.data_file = None
+ self.read_index(path)
+
+ def read_index(self, path):
+ with open(index_file_path(path), 'rb') as f:
+ magic = f.read(8)
+ assert magic == self._HDR_MAGIC, (
+ 'Index file doesn\'t match expected format. '
+ 'Make sure that --dataset-impl is configured properly.'
+ )
+ version = f.read(8)
+ assert struct.unpack('<Q', version) == (1,)
+ code, self.element_size = struct.unpack('<QQ', f.read(16))
+ self.dtype = dtypes[code]
+ self._len, self.s = struct.unpack('<QQ', f.read(16))
+ self.doc_count = struct.unpack('<Q', f.read(8))
+ self.dim_offsets = read_longs(f, self._len + 1)
+ self.data_offsets = read_longs(f, self._len + 1)
+ self.sizes = read_longs(f, self.s)
+ self.doc_idx = read_longs(f, self.doc_count)
+
+ def read_data(self, path):
+ self.data_file = open(data_file_path(path), 'rb', buffering=0)
+
+ def check_index(self, i):
+ if i < 0 or i >= self._len:
+ raise IndexError('index out of range')
+
+ def __del__(self):
+ if self.data_file:
+ self.data_file.close()
+
+ # @lru_cache(maxsize=8)
+ def __getitem__(self, idx):
+ if not self.data_file:
+ self.read_data(self.path)
+ if isinstance(idx, int):
+ i = idx
+ self.check_index(i)
+ tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+ a = np.empty(tensor_size, dtype=self.dtype)
+ self.data_file.seek(self.data_offsets[i] * self.element_size)
+ self.data_file.readinto(a)
+ return a
+ elif isinstance(idx, slice):
+ start, stop, step = idx.indices(len(self))
+ if step != 1:
+ raise ValueError("Slices into indexed_dataset must be contiguous")
+ sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
+ size = sum(sizes)
+ a = np.empty(size, dtype=self.dtype)
+ self.data_file.seek(self.data_offsets[start] * self.element_size)
+ self.data_file.readinto(a)
+ offsets = list(accumulate(sizes))
+ sents = np.split(a, offsets[:-1])
+ return sents
+
+ def __len__(self):
+ return self._len
+
+ def num_tokens(self, index):
+ return self.sizes[index]
+
+ def size(self, index):
+ return self.sizes[index]
+
+ @staticmethod
+ def exists(path):
+ return (
+ os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
+ )
+
+ @property
+ def supports_prefetch(self):
+ return False # avoid prefetching to save memory
+
+
+class IndexedCachedDataset(IndexedDataset):
+
+ def __init__(self, path):
+ super().__init__(path)
+ self.cache = None
+ self.cache_index = {}
+
+ @property
+ def supports_prefetch(self):
+ return True
+
+ def prefetch(self, indices):
+ if all(i in self.cache_index for i in indices):
+ return
+ if not self.data_file:
+ self.read_data(self.path)
+ indices = sorted(set(indices))
+ total_size = 0
+ for i in indices:
+ total_size += self.data_offsets[i + 1] - self.data_offsets[i]
+ self.cache = np.empty(total_size, dtype=self.dtype)
+ ptx = 0
+ self.cache_index.clear()
+ for i in indices:
+ self.cache_index[i] = ptx
+ size = self.data_offsets[i + 1] - self.data_offsets[i]
+ a = self.cache[ptx: ptx + size]
+ self.data_file.seek(self.data_offsets[i] * self.element_size)
+ self.data_file.readinto(a)
+ ptx += size
+ if self.data_file:
+ # close and delete data file after prefetch so we can pickle
+ self.data_file.close()
+ self.data_file = None
+
+ # @lru_cache(maxsize=8)
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ i = idx
+ self.check_index(i)
+ tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+ a = np.empty(tensor_size, dtype=self.dtype)
+ ptx = self.cache_index[i]
+ np.copyto(a, self.cache[ptx: ptx + a.size])
+ return a
+ elif isinstance(idx, slice):
+ # Hack just to make this work; can optimize later if necessary
+ sents = []
+ for i in range(*idx.indices(len(self))):
+ sents.append(self[i])
+ return sents
+
+
+class IndexedDatasetBuilder(object):
+ element_sizes = {
+ np.uint8: 1,
+ np.int8: 1,
+ np.int16: 2,
+ np.int32: 4,
+ np.int64: 8,
+ np.float: 4,
+ np.double: 8
+ }
+
+ def __init__(self, out_file, dtype=np.int32):
+ self.out_file = open(out_file, 'wb')
+ self.dtype = dtype
+ self.data_offsets = [0]
+ self.dim_offsets = [0]
+ self.sizes = []
+ self.element_size = self.element_sizes[self.dtype]
+ self.doc_idx = [0]
+
+ def add_item(self, tensor):
+ bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
+ self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
+ for s in tensor.size():
+ self.sizes.append(s)
+ self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
+
+ def end_document(self):
+ self.doc_idx.append(len(self.sizes))
+
+ def merge_file_(self, another_file):
+ index = IndexedDataset(another_file)
+ assert index.dtype == self.dtype
+
+ begin = self.data_offsets[-1]
+ for offset in index.data_offsets[1:]:
+ self.data_offsets.append(begin + offset)
+ self.sizes.extend(index.sizes)
+ begin = self.dim_offsets[-1]
+ for dim_offset in index.dim_offsets[1:]:
+ self.dim_offsets.append(begin + dim_offset)
+
+ with open(data_file_path(another_file), 'rb') as f:
+ while True:
+ data = f.read(1024)
+ if data:
+ self.out_file.write(data)
+ else:
+ break
+
+ def finalize(self, index_file):
+ self.out_file.close()
+ index = open(index_file, 'wb')
+ index.write(b'TNTIDX\x00\x00')
+ index.write(struct.pack('<Q', 1))
+ index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
+ index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
+ index.write(struct.pack('<Q', len(self.doc_idx)))
+ write_longs(index, self.dim_offsets)
+ write_longs(index, self.data_offsets)
+ write_longs(index, self.sizes)
+ write_longs(index, self.doc_idx)
+ index.close()
+
+
+def _warmup_mmap_file(path):
+ with open(path, 'rb') as stream:
+ while stream.read(100 * 1024 * 1024):
+ pass
+
+
+class MMapIndexedDataset(torch.utils.data.Dataset):
+ class Index(object):
+ _HDR_MAGIC = b'MMIDIDX\x00\x00'
+
+ @classmethod
+ def writer(cls, path, dtype):
+ class _Writer(object):
+ def __enter__(self):
+ self._file = open(path, 'wb')
+
+ self._file.write(cls._HDR_MAGIC)
+ self._file.write(struct.pack('<Q', 1))
+ self._file.write(struct.pack('<B', code(dtype)))
+
+ return self
+
+ @staticmethod
+ def _get_pointers(sizes):
+ dtype_size = dtype().itemsize
+ address = 0
+ pointers = []
+
+ for size in sizes:
+ pointers.append(address)
+ address += size * dtype_size
+
+ return pointers
+
+ def write(self, sizes, doc_idx):
+ pointers = self._get_pointers(sizes)
+
+ self._file.write(struct.pack('<Q', len(sizes)))
+ self._file.write(struct.pack('<Q', len(doc_idx)))
+
+ sizes = np.array(sizes, dtype=np.int32)
+ self._file.write(sizes.tobytes(order='C'))
+ del sizes
+
+ pointers = np.array(pointers, dtype=np.int64)
+ self._file.write(pointers.tobytes(order='C'))
+ del pointers
+
+ doc_idx = np.array(doc_idx, dtype=np.int64)
+ self._file.write(doc_idx.tobytes(order='C'))
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._file.close()
+
+ return _Writer()
+
+ def __init__(self, path, skip_warmup=False):
+ with open(path, 'rb') as stream:
+ magic_test = stream.read(9)
+ assert self._HDR_MAGIC == magic_test, (
+ 'Index file doesn\'t match expected format. '
+ 'Make sure that --dataset-impl is configured properly.'
+ )
+ version = struct.unpack('<Q', stream.read(8))
+ assert (1,) == version
+
+ dtype_code, = struct.unpack('<B', stream.read(1))
+ self._dtype = dtypes[dtype_code]
+ self._dtype_size = self._dtype().itemsize
+
+ self._len = struct.unpack('<Q', stream.read(8))[0]
+ self._doc_count = struct.unpack('<Q', stream.read(8))[0]
+ offset = stream.tell()
+
+ if not skip_warmup:
+ print_rank_0(" warming up index mmap file...")
+ _warmup_mmap_file(path)
+
+ self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
+ print_rank_0(" reading sizes...")
+ self._sizes = np.frombuffer(
+ self._bin_buffer,
+ dtype=np.int32,
+ count=self._len,
+ offset=offset)
+ print_rank_0(" reading pointers...")
+ self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
+ offset=offset + self._sizes.nbytes)
+ print_rank_0(" reading document index...")
+ self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
+ offset=offset + self._sizes.nbytes + self._pointers.nbytes)
+
+ def __del__(self):
+ self._bin_buffer_mmap._mmap.close()
+ del self._bin_buffer_mmap
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def sizes(self):
+ return self._sizes
+
+ @property
+ def doc_idx(self):
+ return self._doc_idx
+
+ @lru_cache(maxsize=8)
+ def __getitem__(self, i):
+ return self._pointers[i], self._sizes[i]
+
+ def __len__(self):
+ return self._len
+
+ def __init__(self, path, skip_warmup=False):
+ super().__init__()
+
+ self._path = None
+ self._index = None
+ self._bin_buffer = None
+
+ self._do_init(path, skip_warmup)
+
+ def __getstate__(self):
+ return self._path
+
+ def __setstate__(self, state):
+ self._do_init(state)
+
+ def _do_init(self, path, skip_warmup):
+ self._path = path
+ self._index = self.Index(index_file_path(self._path), skip_warmup)
+
+ if not skip_warmup:
+ print_rank_0(" warming up data mmap file...")
+ _warmup_mmap_file(data_file_path(self._path))
+ print_rank_0(" creating numpy buffer of mmap...")
+ self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
+ print_rank_0(" creating memory view of numpy buffer...")
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
+
+ def __del__(self):
+ self._bin_buffer_mmap._mmap.close()
+ del self._bin_buffer_mmap
+ del self._index
+
+ def __len__(self):
+ return len(self._index)
+
+ # @lru_cache(maxsize=8)
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ ptr, size = self._index[idx]
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
+ count=size, offset=ptr)
+ return np_array
+ elif isinstance(idx, slice):
+ start, stop, step = idx.indices(len(self))
+ if step != 1:
+ raise ValueError("Slices into indexed_dataset must be contiguous")
+ ptr = self._index._pointers[start]
+ sizes = self._index._sizes[idx]
+ offsets = list(accumulate(sizes))
+ total_size = sum(sizes)
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
+ count=total_size, offset=ptr)
+ sents = np.split(np_array, offsets[:-1])
+ return sents
+
+ def get(self, idx, offset=0, length=None):
+ """ Retrieves a single item from the dataset with the option to only
+ return a portion of the item.
+
+ get(idx) is the same as [idx] but get() does not support slicing.
+ """
+ ptr, size = self._index[idx]
+ if length is None:
+ length = size - offset
+ ptr += offset * np.dtype(self._index.dtype).itemsize
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
+ count=length, offset=ptr)
+ return np_array
+
+ @property
+ def sizes(self):
+ return self._index.sizes
+
+ @property
+ def doc_idx(self):
+ return self._index.doc_idx
+
+ def get_doc_idx(self):
+ return self._index._doc_idx
+
+ def set_doc_idx(self, doc_idx_):
+ self._index._doc_idx = doc_idx_
+
+ @property
+ def supports_prefetch(self):
+ return False
+
+ @staticmethod
+ def exists(path):
+ return (
+ os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
+ )
+
+
+class MMapIndexedDatasetBuilder(object):
+ def __init__(self, out_file, dtype=np.int64):
+ self._data_file = open(out_file, 'wb')
+ self._dtype = dtype
+ self._sizes = []
+ self._doc_idx = [0]
+
+ def add_item(self, tensor):
+ np_array = np.array(tensor.numpy(), dtype=self._dtype)
+ self._data_file.write(np_array.tobytes(order='C'))
+ self._sizes.append(np_array.size)
+
+ def end_document(self):
+ self._doc_idx.append(len(self._sizes))
+
+ def merge_file_(self, another_file):
+ # Concatenate index
+ index = MMapIndexedDataset.Index(index_file_path(another_file))
+ assert index.dtype == self._dtype
+
+ for size in index.sizes:
+ self._sizes.append(size)
+
+ # Concatenate data
+ with open(data_file_path(another_file), 'rb') as f:
+ shutil.copyfileobj(f, self._data_file)
+
+ def finalize(self, index_file):
+ self._data_file.close()
+
+ with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
+ index.write(self._sizes, self._doc_idx)
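The builder/reader pair above defines the mmap-backed binary format used for the pretraining data: MMapIndexedDatasetBuilder streams raw token ids into the data file, while the nested Index writer records per-sentence sizes, byte pointers, and document boundaries. A minimal usage sketch follows; it assumes the module is importable as megatron.data.indexed_dataset, that data_file_path()/index_file_path() map a prefix to the data and index file paths (they are used that way elsewhere in this file), and that np.int32 is one of the supported dtype codes. The toy documents are made up for illustration.

import numpy as np
import torch
from megatron.data import indexed_dataset

prefix = "toy_corpus"  # hypothetical output prefix

builder = indexed_dataset.MMapIndexedDatasetBuilder(
    indexed_dataset.data_file_path(prefix), dtype=np.int32)

# Two toy "documents", each a list of tokenized sentences.
documents = [[[5, 6, 7], [8, 9]], [[10, 11, 12, 13]]]
for document in documents:
    for sentence in document:
        builder.add_item(torch.IntTensor(sentence))  # append raw bytes, remember the size
    builder.end_document()                           # record a document boundary
builder.finalize(indexed_dataset.index_file_path(prefix))

# Read it back through the mmap reader.
ds = indexed_dataset.MMapIndexedDataset(prefix, skip_warmup=True)
print(len(ds), ds.sizes, ds.doc_idx)  # 3 sentences, sizes [3 2 4], doc_idx [0 2 3]
print(ds.get(2, offset=1, length=2))  # part of the third sentence: [11 12]

The get() call exercises the partial-read path described in its docstring: the byte pointer is advanced by offset elements and only length elements are materialized from the memory map.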
diff --git a/megatron_lm/megatron/data/realm_dataset_utils.py b/megatron_lm/megatron/data/realm_dataset_utils.py
new file mode 100644
index 0000000..68aed4a
--- /dev/null
+++ b/megatron_lm/megatron/data/realm_dataset_utils.py
@@ -0,0 +1,201 @@
+import os
+import time
+
+import numpy as np
+import torch
+
+from megatron import mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
+from megatron.data.samplers import DistributedBatchSampler
+from megatron import get_args, get_tokenizer
+
+
+def get_one_epoch_dataloader(dataset, batch_size=None):
+ """Specifically one epoch to be used in an indexing job."""
+ args = get_args()
+
+ world_size = mpu.get_data_parallel_world_size()
+ rank = mpu.get_data_parallel_rank()
+ if batch_size is None:
+ batch_size = args.batch_size
+ global_batch_size = batch_size * world_size
+ num_workers = args.num_workers
+
+ sampler = torch.utils.data.SequentialSampler(dataset)
+ # importantly, drop_last must be False to get all the data.
+ batch_sampler = DistributedBatchSampler(sampler,
+ batch_size=global_batch_size,
+ drop_last=False,
+ rank=rank,
+ world_size=world_size)
+
+ return torch.utils.data.DataLoader(dataset,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ pin_memory=True)
+
+
+def get_ict_batch(data_iterator):
+ # Items and their type.
+ keys = ['query_tokens', 'query_pad_mask',
+ 'block_tokens', 'block_pad_mask', 'block_data']
+ datatype = torch.int64
+
+ # Broadcast data.
+ if data_iterator is None:
+ data = None
+ else:
+ data = next(data_iterator)
+ data_b = mpu.broadcast_data(keys, data, datatype)
+
+ # Unpack.
+ query_tokens = data_b['query_tokens'].long()
+ query_pad_mask = data_b['query_pad_mask'].long()
+ block_tokens = data_b['block_tokens'].long()
+ block_pad_mask = data_b['block_pad_mask'].long()
+ block_indices = data_b['block_data'].long()
+
+ return query_tokens, query_pad_mask,\
+ block_tokens, block_pad_mask, block_indices
+
+
+def join_str_list(str_list):
+ """Join a list of strings, handling spaces appropriately"""
+ result = ""
+ for s in str_list:
+ if s.startswith("##"):
+ result += s[2:]
+ else:
+ result += " " + s
+ return result
+
+
+class BlockSampleData(object):
+ """A struct for fully describing a fixed-size block of data as used in REALM
+
+    :param start_idx: index of the first sentence of the block
+    :param end_idx: index of the last sentence of the block (may be partially truncated in sample construction)
+ :param doc_idx: the index of the document from which the block comes in the original indexed dataset
+ :param block_idx: a unique integer identifier given to every block.
+ """
+ def __init__(self, start_idx, end_idx, doc_idx, block_idx):
+ self.start_idx = start_idx
+ self.end_idx = end_idx
+ self.doc_idx = doc_idx
+ self.block_idx = block_idx
+
+ def as_array(self):
+ return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
+
+ def as_tuple(self):
+ return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
+
+
+class BlockSamplesMapping(object):
+ def __init__(self, mapping_array):
+ # make sure that the array is compatible with BlockSampleData
+ assert mapping_array.shape[1] == 4
+ self.mapping_array = mapping_array
+
+ def __len__(self):
+ return self.mapping_array.shape[0]
+
+ def __getitem__(self, idx):
+ """Get the data associated with an indexed sample."""
+ sample_data = BlockSampleData(*self.mapping_array[idx])
+ return sample_data
+
+
+def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
+ max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
+ """Get samples mapping for a dataset over fixed size blocks. This function also requires
+ a dataset of the titles for the source documents since their lengths must be taken into account.
+
+ :return: samples_mapping (BlockSamplesMapping)
+ """
+
+ if not num_epochs:
+ if not max_num_samples:
+ raise ValueError("Need to specify either max_num_samples "
+ "or num_epochs")
+ num_epochs = np.iinfo(np.int32).max - 1
+ if not max_num_samples:
+ max_num_samples = np.iinfo(np.int64).max - 1
+
+ # Filename of the index mapping
+ indexmap_filename = data_prefix
+ indexmap_filename += '_{}_indexmap'.format(name)
+ if num_epochs != (np.iinfo(np.int32).max - 1):
+ indexmap_filename += '_{}ep'.format(num_epochs)
+ if max_num_samples != (np.iinfo(np.int64).max - 1):
+ indexmap_filename += '_{}mns'.format(max_num_samples)
+ indexmap_filename += '_{}msl'.format(max_seq_length)
+ indexmap_filename += '_{}s'.format(seed)
+ if use_one_sent_docs:
+ indexmap_filename += '_1sentok'
+ indexmap_filename += '.npy'
+
+ # Build the indexed mapping if not exist.
+ if mpu.get_data_parallel_rank() == 0 and \
+ not os.path.isfile(indexmap_filename):
+ print(' > WARNING: could not find index map file {}, building '
+ 'the indices on rank 0 ...'.format(indexmap_filename))
+
+ # Make sure the types match the helpers input types.
+ assert block_dataset.doc_idx.dtype == np.int64
+ assert block_dataset.sizes.dtype == np.int32
+
+ # Build samples mapping
+ verbose = torch.distributed.get_rank() == 0
+ start_time = time.time()
+ print_rank_0(' > building samples index mapping for {} ...'.format(
+ name))
+
+ # compile/bind the C++ helper code
+ from megatron.data.dataset_utils import compile_helper
+ compile_helper()
+
+ from megatron.data import helpers
+ mapping_array = helpers.build_blocks_mapping(
+ block_dataset.doc_idx,
+ block_dataset.sizes,
+ title_dataset.sizes,
+ num_epochs,
+ max_num_samples,
+ max_seq_length - 3, # account for added tokens
+ seed,
+ verbose,
+ use_one_sent_docs)
+
+
+ print_rank_0(' > done building samples index mapping')
+ np.save(indexmap_filename, mapping_array, allow_pickle=True)
+ print_rank_0(' > saved the index mapping in {}'.format(
+ indexmap_filename))
+ # Make sure all the ranks have built the mapping
+ print_rank_0(' > elapsed time to build and save samples mapping '
+ '(seconds): {:4f}'.format(
+ time.time() - start_time))
+
+ # This should be a barrier but nccl barrier assumes
+ # device_index=rank which is not the case for model
+ # parallel case
+ counts = torch.cuda.LongTensor([1])
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+ assert counts[0].item() == torch.distributed.get_world_size(
+ group=mpu.get_data_parallel_group())
+
+ # Load indexed dataset.
+ print_rank_0(' > loading indexed mapping from {}'.format(
+ indexmap_filename))
+ start_time = time.time()
+
+ mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
+ samples_mapping = BlockSamplesMapping(mapping_array)
+
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
+ time.time() - start_time))
+ print_rank_0(' total number of samples: {}'.format(
+ mapping_array.shape[0]))
+
+ return samples_mapping
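get_block_samples_mapping() caches its result as a .npy file whose name encodes the epochs, sample cap, sequence length and seed, and the returned BlockSamplesMapping is a thin view over an (N, 4) int64 array of (start_idx, end_idx, doc_idx, block_idx) rows. A small sketch of the wrapper itself, using a hand-built array instead of the C++ helpers (the values are made up):

import numpy as np
from megatron.data.realm_dataset_utils import BlockSamplesMapping

# Hypothetical mapping: one row per block, (start_idx, end_idx, doc_idx, block_idx).
mapping_array = np.array([[0, 4, 0, 0],
                          [4, 9, 0, 1],
                          [9, 12, 1, 2]], dtype=np.int64)
samples_mapping = BlockSamplesMapping(mapping_array)

print(len(samples_mapping))   # 3 blocks
sample = samples_mapping[1]   # a BlockSampleData instance
print(sample.as_tuple())      # the block's (start_idx, end_idx, doc_idx, block_idx)
print(sample.as_array())      # the same row as an int64 array

In the full pipeline the array comes from helpers.build_blocks_mapping and is saved to the indexmap filename assembled above, then reloaded with mmap_mode='r' on every rank.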
diff --git a/megatron_lm/megatron/data/realm_index.py b/megatron_lm/megatron/data/realm_index.py
new file mode 100644
index 0000000..54344e0
--- /dev/null
+++ b/megatron_lm/megatron/data/realm_index.py
@@ -0,0 +1,216 @@
+import itertools
+import os
+import pickle
+import shutil
+
+import numpy as np
+import torch
+
+from megatron import get_args
+from megatron import mpu
+
+
+def detach(tensor):
+ return tensor.detach().cpu().numpy()
+
+
+class BlockData(object):
+ """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
+ def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+ self.embed_data = dict()
+ self.meta_data = dict()
+ if block_data_path is None:
+ args = get_args()
+ block_data_path = args.block_data_path
+ rank = args.rank
+ self.block_data_path = block_data_path
+ self.rank = rank
+
+ if load_from_path:
+ self.load_from_file()
+
+ block_data_name = os.path.splitext(self.block_data_path)[0]
+ self.temp_dir_name = block_data_name + '_tmp'
+
+ def state(self):
+ return {
+ 'embed_data': self.embed_data,
+ 'meta_data': self.meta_data,
+ }
+
+ def clear(self):
+ """Clear the embedding data structures to save memory.
+ The metadata ends up getting used, and is also much smaller in dimensionality
+ so it isn't really worth clearing.
+ """
+ self.embed_data = dict()
+
+ def load_from_file(self):
+ """Populate members from instance saved to file"""
+
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+ print("\n> Unpickling BlockData", flush=True)
+ state_dict = pickle.load(open(self.block_data_path, 'rb'))
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+ print(">> Finished unpickling BlockData\n", flush=True)
+
+ self.embed_data = state_dict['embed_data']
+ self.meta_data = state_dict['meta_data']
+
+ def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+ """Add data for set of blocks
+ :param block_indices: 1D array of unique int ids for the blocks
+ :param block_embeds: 2D array of embeddings of the blocks
+ :param block_metas: 2D array of metadata for the blocks.
+ In the case of REALM this will be [start_idx, end_idx, doc_idx]
+ """
+ for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+ if not allow_overwrite and idx in self.embed_data:
+ raise ValueError("Unexpectedly tried to overwrite block data")
+
+ self.embed_data[idx] = np.float16(embed)
+ self.meta_data[idx] = meta
+
+ def save_shard(self):
+ """Save the block data that was created this in this process"""
+ if not os.path.isdir(self.temp_dir_name):
+ os.makedirs(self.temp_dir_name, exist_ok=True)
+
+ # save the data for each shard
+ with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
+ pickle.dump(self.state(), data_file)
+
+ def merge_shards_and_save(self):
+ """Combine all the shards made using self.save_shard()"""
+ shard_names = os.listdir(self.temp_dir_name)
+ seen_own_shard = False
+
+ for fname in os.listdir(self.temp_dir_name):
+ shard_rank = int(os.path.splitext(fname)[0])
+ if shard_rank == self.rank:
+ seen_own_shard = True
+ continue
+
+ with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
+ data = pickle.load(f)
+ old_size = len(self.embed_data)
+ shard_size = len(data['embed_data'])
+
+ # add the shard's data and check to make sure there is no overlap
+ self.embed_data.update(data['embed_data'])
+ self.meta_data.update(data['meta_data'])
+ assert len(self.embed_data) == old_size + shard_size
+
+ assert seen_own_shard
+
+ # save the consolidated shards and remove temporary directory
+ with open(self.block_data_path, 'wb') as final_file:
+ pickle.dump(self.state(), final_file)
+ shutil.rmtree(self.temp_dir_name, ignore_errors=True)
+
+ print("Finished merging {} shards for a total of {} embeds".format(
+ len(shard_names), len(self.embed_data)), flush=True)
+
+
+class FaissMIPSIndex(object):
+ """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
+ def __init__(self, embed_size, block_data=None, use_gpu=False):
+ self.embed_size = embed_size
+ self.block_data = block_data
+ self.use_gpu = use_gpu
+ self.id_map = dict()
+
+ self.block_mips_index = None
+ self._set_block_index()
+
+ def _set_block_index(self):
+ """Create a Faiss Flat index with inner product as the metric to search against"""
+ try:
+ import faiss
+ except ImportError:
+ raise Exception("Error: Please install faiss to use FaissMIPSIndex")
+
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+ print("\n> Building index", flush=True)
+ self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+
+ if self.use_gpu:
+ # create resources and config for GpuIndex
+ res = faiss.StandardGpuResources()
+ config = faiss.GpuIndexFlatConfig()
+ config.device = torch.cuda.current_device()
+ config.useFloat16 = True
+
+ self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+ print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
+ else:
+ # CPU index supports IDs so wrap with IDMap
+ self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+ print(">> Initialized index on CPU", flush=True)
+
+ # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
+ if self.block_data is not None:
+ self.add_block_embed_data(self.block_data)
+
+ def reset_index(self):
+ """Delete existing index and create anew"""
+ del self.block_mips_index
+
+ # reset the block data so that _set_block_index will reload it as well
+ if self.block_data is not None:
+ block_data_path = self.block_data.block_data_path
+ del self.block_data
+ self.block_data = BlockData(block_data_path)
+
+ self._set_block_index()
+
+ def add_block_embed_data(self, all_block_data):
+ """Add the embedding of each block to the underlying FAISS index"""
+
+ # this assumes the embed_data is a dict : {int: np.array<float>}
+ block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+
+ # the embeddings have to be entered in as float32 even though the math internally is done with float16.
+ block_embeds_arr = np.float32(np.array(block_embeds))
+ block_indices_arr = np.array(block_indices)
+
+        # faiss GpuIndex doesn't work with the IDMap wrapper, so store the ids separately to map results back
+ if self.use_gpu:
+ for i, idx in enumerate(block_indices):
+ self.id_map[i] = idx
+
+ # we no longer need the embedding data since it's in the index now
+ all_block_data.clear()
+
+ if self.use_gpu:
+ self.block_mips_index.add(block_embeds_arr)
+ else:
+ self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
+
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+ print(">>> Finished adding block data to index", flush=True)
+
+ def search_mips_index(self, query_embeds, top_k, reconstruct=True):
+ """Get the top-k blocks by the index distance metric.
+
+ :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
+ if False: return [num_queries x k] array of distances, and another for indices
+ """
+ query_embeds = np.float32(detach(query_embeds))
+
+ if reconstruct:
+ # get the vectors themselves
+ top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+ return top_k_block_embeds
+
+ else:
+ # get distances and indices of closest vectors
+ distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+ if self.use_gpu:
+ fresh_indices = np.zeros(block_indices.shape)
+                for i, j in itertools.product(range(block_indices.shape[0]), range(block_indices.shape[1])):
+ fresh_indices[i, j] = self.id_map[block_indices[i, j]]
+ block_indices = fresh_indices
+ return distances, block_indices
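FaissMIPSIndex builds a flat inner-product index over the float16 embeddings stored in a BlockData and answers top-k maximum-inner-product queries. A hedged sketch of the CPU path, assuming the faiss package is installed; the block store, embeddings, and metadata below are synthetic.

import numpy as np
import torch
from megatron.data.realm_index import BlockData, FaissMIPSIndex

embed_size = 8
rng = np.random.RandomState(0)

# Synthetic block store: ids 0..99 with random embeddings and (start, end, doc) metadata.
block_data = BlockData(block_data_path='blocks.pkl', load_from_path=False, rank=0)
block_data.add_block_data(
    block_indices=np.arange(100),
    block_embeds=rng.randn(100, embed_size).astype(np.float32),
    block_metas=[(i, i + 5, 0) for i in range(100)])

# CPU flat index with inner-product metric; the embeddings are added on construction
# (and block_data.embed_data is cleared once they live inside the index).
mips_index = FaissMIPSIndex(embed_size=embed_size, block_data=block_data, use_gpu=False)

queries = torch.randn(4, embed_size)  # 4 query embeddings
distances, block_ids = mips_index.search_mips_index(queries, top_k=5, reconstruct=False)
print(distances.shape, block_ids.shape)  # (4, 5) (4, 5)

With reconstruct=True the call instead returns the reconstructed top-k block embeddings via search_and_reconstruct.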
diff --git a/megatron_lm/megatron/data/samplers.py b/megatron_lm/megatron/data/samplers.py
new file mode 100644
index 0000000..2fbd070
--- /dev/null
+++ b/megatron_lm/megatron/data/samplers.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Batch samplers that work with either random or sequential data samplers."""
+
+import torch
+from torch.utils import data
+
+
+class RandomSampler(data.sampler.Sampler):
+ """Based off of pytorch RandomSampler and DistributedSampler. Essentially
+ a RandomSampler, but this class lets the user set an epoch like
+ DistributedSampler Samples elements randomly. If without replacement, then
+ sample from a shuffled dataset. If with replacement, then user can
+ specify ``num_samples`` to draw.
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ num_samples (int): number of samples to draw, default=len(dataset)
+ replacement (bool): samples are drawn with replacement if ``True``,
+ default=False
+ """
+
+ def __init__(self, data_source, replacement=False, num_samples=None):
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+ self.epoch = -1
+
+ if self._num_samples is not None and replacement is False:
+ raise ValueError("With replacement=False, num_samples should not "
+ "be specified, since a random permute will be "
+ "performed.")
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError("num_samples should be a positive integer "
+ "value, but got num_samples={}".format(
+ self.num_samples))
+ if not isinstance(self.replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(self.replacement))
+
+ @property
+ def num_samples(self):
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self):
+ n = len(self.data_source)
+ g = torch.Generator()
+ if self.epoch >= 0:
+ g.manual_seed(self.epoch)
+ if self.replacement:
+ return iter(torch.randint(high=n, size=(self.num_samples,),
+ dtype=torch.int64, generator=g).tolist())
+ return iter(torch.randperm(n, generator=g).tolist())
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class DistributedBatchSampler(data.sampler.BatchSampler):
+ """Similar to normal implementation of distributed sampler, except
+ implementation is at the batch sampler level, instead of just the
+ sampler level. This allows wrapping of arbitrary data samplers
+ (sequential, random, WeightedRandomSampler, etc.) with this batch
+ sampler.
+
+ The `interleave` argument specifies how to distribute a batch. A value
+ of True combined with the above random sampler is equivalent to pytorch's
+ torch.utils.data.distributed.DistributedSampler.
+
+ For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2
+ specifying True will result in the following samples for each gpu:
+ GPU0: [0,2,4,6] GPU1: [1,3,5,7]
+ specifying False will result in the following samples:
+ GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
+
+ def __init__(self, sampler, batch_size, drop_last, rank=-1,
+ world_size=2, wrap_last=False, interleave=False):
+ super(DistributedBatchSampler, self).__init__(sampler, batch_size,
+ drop_last)
+ if rank == -1:
+ assert False, 'should not be here'
+ rank = torch.distributed.get_rank()
+ self.rank = rank
+ self.world_size = world_size
+ self.sampler.wrap_around = 0
+ self.wrap_around = 0
+ self.wrap_last = wrap_last
+ self.start_iter = 0
+ self.interleave = interleave
+
+ def __iter__(self):
+ batch = []
+ i = 0
+ for idx in self.data_iterator(self.sampler, wrap_around=False):
+ batch.append(idx)
+ if len(batch) == self.batch_size:
+ tbatch = self._batch(batch)
+ if i >= self.start_iter:
+ yield tbatch
+ self.start_iter = 0
+ i += 1
+ batch = []
+ batch_len = len(batch)
+ if batch_len > 0 and not self.drop_last:
+ if self.wrap_last:
+ self.sampler.wrap_around -= (self.batch_size)
+ self.wrap_around += (len(batch))
+ self.wrap_around %= self.batch_size
+ yield self._batch(batch)
+ if self.wrap_last:
+ self.sampler.wrap_around += self.batch_size
+
+ def data_iterator(self, _iter, wrap_around=False):
+ """iterates through data and handles wrap around"""
+ for i, idx in enumerate(_iter):
+ if i < self.wrap_around % self.batch_size:
+ continue
+ if wrap_around:
+ self.wrap_around += 1
+ self.wrap_around %= self.batch_size
+ yield idx
+
+ def _batch(self, batch):
+ """extracts samples only pertaining to this worker's batch"""
+ if self.interleave:
+ return batch[self.rank:self.batch_size:self.world_size]
+ start = self.rank * self.batch_size // self.world_size
+ end = (self.rank + 1) * self.batch_size // self.world_size
+ return batch[start:end]
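The docstring's [0, 1, ..., 7] example can be reproduced directly to see the difference between interleaved and contiguous partitioning. A short sketch, assuming the module imports as megatron.data.samplers and passing rank/world_size explicitly so no process group is needed:

import torch
from megatron.data.samplers import DistributedBatchSampler

dataset = list(range(8))  # toy dataset whose "samples" are just their own indices

for interleave in (True, False):
    per_rank = []
    for rank in (0, 1):
        batch_sampler = DistributedBatchSampler(
            torch.utils.data.SequentialSampler(dataset),
            batch_size=8, drop_last=False,
            rank=rank, world_size=2, interleave=interleave)
        per_rank.append(list(batch_sampler)[0])
    print(interleave, per_rank)

# True  [[0, 2, 4, 6], [1, 3, 5, 7]]
# False [[0, 1, 2, 3], [4, 5, 6, 7]]

Note that batch_size here is the global batch size; each rank receives batch_size // world_size samples per step, which matches how get_one_epoch_dataloader() above multiplies the per-rank batch size by the data-parallel world size.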
diff --git a/megatron_lm/megatron/data/test/test_indexed_dataset.py b/megatron_lm/megatron/data/test/test_indexed_dataset.py
new file mode 100644
index 0000000..9103c6d
--- /dev/null
+++ b/megatron_lm/megatron/data/test/test_indexed_dataset.py
@@ -0,0 +1,125 @@
+# This file isn't really a formal automated test, it's just a place to
+# put some code used during development and manual testing of
+# indexed_dataset.
+
+from megatron.data import indexed_dataset
+from megatron.tokenizer import build_tokenizer
+import argparse
+import os
+import sys
+
+import torch
+
+script_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.join(script_dir, "../../../"))
+
+
+def test_indexed_dataset(args):
+ ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
+ tokenizer = build_tokenizer(args)
+ print(len(ds.doc_idx))
+ print(len(ds))
+ print(ds.doc_idx[-1])
+ if ds.supports_prefetch:
+ # just prefetch the whole thing in test (so assume it is small)
+ ds.prefetch(range(len(ds)))
+ if args.count > len(ds.doc_idx) - 1:
+ args.count = len(ds.doc_idx) - 1
+
+ for i in range(args.count):
+ start = ds.doc_idx[i]
+ end = ds.doc_idx[i + 1]
+ ids = ds[start:end]
+ print(f"Document {i}:")
+ print("--------------")
+ for s in ids:
+ assert len(s) > 0
+ l = s.data.tolist()
+ text = tokenizer.detokenize(l)
+ print(text)
+ print("---")
+
+
+def test_indexed_dataset_get(args):
+ ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
+ tokenizer = build_tokenizer(args)
+ size = ds.sizes[0]
+ print(f"size: {size}")
+ full = ds.get(0)
+ print(full)
+ # print(tokenizer.detokenize(full.data.tolist()))
+ print("---")
+ end = ds.get(0, offset=size - 10)
+ print(end)
+ # print(tokenizer.detokenize(end.data.tolist()))
+
+ start = ds.get(0, length=10)
+ print(start)
+ # print(tokenizer.detokenize(start.data.tolist()))
+
+ part = ds.get(0, offset=2, length=8)
+ print(part)
+ # print(tokenizer.detokenize(part.data.tolist()))
+
+# def test_albert_dataset(args):
+# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
+# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
+# # ds = AlbertDataset(idataset, tokenizer)
+# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
+# args.epochs, args.max_num_samples,
+# args.masked_lm_prob, args.seq_length,
+# args.short_seq_prob, args.seed)
+# truncated = 0
+# total = 0
+# for i, s in enumerate(ds):
+# ids = s['text']
+# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
+# print(tokens)
+# if i >= args.count-1:
+# exit()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data', type=str, help='prefix to data files')
+ parser.add_argument('--dataset-impl', type=str, default='infer',
+ choices=['lazy', 'cached', 'mmap', 'infer'])
+ parser.add_argument('--count', type=int, default=10,
+ help='Number of samples/documents to print')
+
+ group = parser.add_argument_group(title='tokenizer')
+ group.add_argument('--tokenizer-type', type=str, required=True,
+ choices=['BertWordPieceLowerCase',
+ 'GPT2BPETokenizer'],
+ help='What type of tokenizer to use.')
+ group.add_argument('--vocab-file', type=str, default=None,
+ help='Path to the vocab file')
+ group.add_argument('--merge-file', type=str, default=None,
+ help='Path to the BPE merge file (if necessary).')
+
+ parser.add_argument('--epochs', type=int, default=5,
+ help='Number of epochs to plan for')
+ parser.add_argument('--max-num-samples', type=int, default=None,
+ help='Maximum number of samples to plan for')
+ parser.add_argument('--masked-lm-prob', type=float, default=0.15,
+ help='probability of masking tokens')
+ parser.add_argument('--seq-length', type=int, default=512,
+ help='maximum sequence length')
+ parser.add_argument('--short-seq-prob', type=float, default=0.1,
+ help='probability of creating a short sequence')
+ parser.add_argument('--seed', type=int, default=1234,
+ help='random seed')
+ args = parser.parse_args()
+ args.rank = 0
+ args.make_vocab_size_divisible_by = 128
+ args.model_parallel_size = 1
+
+ if args.dataset_impl == "infer":
+ args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
+
+# test_albert_dataset(args)
+ test_indexed_dataset_get(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/megatron_lm/megatron/data/test/test_preprocess_data.sh b/megatron_lm/megatron/data/test/test_preprocess_data.sh
new file mode 100755
index 0000000..d121c85
--- /dev/null
+++ b/megatron_lm/megatron/data/test/test_preprocess_data.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+IMPL=cached
+python ../preprocess_data.py \
+ --input test_samples.json \
+ --vocab vocab.txt \
+ --dataset-impl ${IMPL} \
+ --output-prefix test_samples_${IMPL} \
+ --workers 1 \
+ --log-interval 2
diff --git a/megatron_lm/megatron/deprecated_data_utils/__init__.py b/megatron_lm/megatron/deprecated_data_utils/__init__.py
new file mode 100644
index 0000000..abefedc
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/__init__.py
@@ -0,0 +1,141 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""utils for creating datasets"""
+import os
+import math
+
+import torch
+
+from .samplers import DistributedBatchSampler
+from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
+from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
+from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
+from . import corpora
+
+TRAIN_DATA = 0
+VAL_DATA = 1
+TEST_DATA = 2
+
+
+def should_split(split):
+ """
+    given split proportions, checks whether the dataset should be split into multiple subsets
+ Examples:
+ >>> should_split([10,0,0])
+ False
+ >>> should_split([1,.1,.2])
+ True
+ """
+ return max(split) / sum(split) != 1.
+
+
+def get_ext(path):
+ """gets path extension"""
+ return os.path.splitext(path)[1]
+
+
+def get_dataset(path, **kwargs):
+ """gets dataset object based on keyword args and file at `path`"""
+ if supported_corpus(path):
+ return corpora.NAMED_CORPORA[path](**kwargs)
+ ext = get_ext(path)
+ if '.json' in ext:
+ text = json_dataset(path, **kwargs)
+ elif ext in ['.csv', '.tsv']:
+ text = csv_dataset(path, **kwargs)
+ else:
+ raise NotImplementedError('data file type %s is not supported' % (ext))
+ return text
+
+
+def supported_corpus(corpus_name):
+ """checks if corpus name is defined in `corpora.py`"""
+ return corpus_name in corpora.NAMED_CORPORA
+
+
+def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
+ delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
+ tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
+ model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
+ parallel_group=None, **kwargs):
+ """function to create datasets+tokenizers for common options"""
+ if isinstance(process_fn, str):
+ process_fn = eval(process_fn)
+ if non_binary_cols is not None:
+ # multilabel dataset support (only for csvs)
+ label_key = non_binary_cols
+
+ def get_dataset_from_path(path_):
+ if lazy:
+ # get lazily loaded dataset
+ named_corpora = False
+ if supported_corpus(path_):
+ named_corpora = True
+ name = path_
+ path_ = corpora.NAMED_CORPORA[path_].PATH
+ if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
+ # create cached version of dataset for lazy loading if it doesn't exist
+ text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
+ delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
+ make_lazy(path_, text.X, data_type='data')
+ # This should be a barrier but nccl barrier assumes
+ # device_index=rank which is not the case for model
+ # parallel case
+ counts = torch.cuda.LongTensor([1])
+ torch.distributed.all_reduce(counts, group=parallel_group)
+ assert counts[0].item() == torch.distributed.get_world_size(
+ group=parallel_group)
+
+ text = lazy_array_loader(path_, data_type='data', map_fn=process_fn)
+ else:
+ # get dataset
+ text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
+ delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
+ return text
+ # get one or multiple datasets and concatenate
+ if isinstance(path, str):
+ path = [path]
+ datasets = [get_dataset_from_path(p) for p in path]
+ if len(datasets) == 1:
+ ds = datasets[0]
+ else:
+ ds = ConcatDataset(datasets)
+ # make tokenizer for dataset
+ if tokenizer is None:
+ tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
+ pad_token, character_converage, **kwargs)
+
+ ds_type = ''
+ if 'ds_type' in kwargs:
+ ds_type = kwargs['ds_type']
+ ds.SetTokenizer(tokenizer)
+ # Split dataset into train/val/test (and wrap bert dataset)
+ if should_split(split):
+ ds = split_ds(ds, split)
+ if 'bert' in ds_type.lower():
+ presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+ dstype = bert_sentencepair_dataset
+ ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
+ if d is not None else None for d in ds]
+ elif ds_type.lower() == 'gpt2':
+ ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
+ else:
+ if 'bert' in ds_type.lower():
+ presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+ dstype = bert_sentencepair_dataset
+ ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
+ elif ds_type.lower() == 'gpt2':
+ ds = GPT2Dataset(ds, max_seq_len=seq_length)
+ return ds, tokenizer
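make_dataset() is the entry point that configure_data.py (below) drives with its data_set_args dictionary: it builds the raw json/csv dataset, constructs a tokenizer, optionally splits into train/val/test, and wraps each split for the requested ds_type. A hedged sketch of a direct call, assuming the deprecated module's own dependencies (nltk, pandas, the tokenization helpers) resolve; the corpus path is hypothetical and should point at a loose-JSON file with one {"text": ...} record per line.

from megatron.deprecated_data_utils import make_dataset, should_split

split = [0.8, 0.1, 0.1]
assert should_split(split)  # max/sum != 1, so three splits will be returned

ds, tokenizer = make_dataset(
    path='data/my_corpus.json',            # hypothetical loose-JSON corpus
    seq_length=256,
    text_key='text',
    label_key='label',
    loose=True,
    split=split,
    ds_type='gpt2',                        # wrap each split in a GPT2Dataset
    tokenizer_type='CharacterLevelTokenizer')
train, valid, test = ds

With the default split=[1.] a single wrapped dataset is returned instead of a list of splits.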
diff --git a/megatron_lm/megatron/deprecated_data_utils/configure_data.py b/megatron_lm/megatron/deprecated_data_utils/configure_data.py
new file mode 100644
index 0000000..357c238
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/configure_data.py
@@ -0,0 +1,252 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""parses arguments and preps data loader"""
+
+import copy
+import torch
+
+from megatron import data_utils
+from megatron import mpu
+
+
+class DataConfig:
+
+ def __init__(self, defaults={}):
+ super(DataConfig, self).__init__()
+ self.defaults = defaults
+
+ def apply(self, args):
+ if torch.distributed.get_rank() == 0:
+ print('configuring data')
+ self.apply_defaults(args)
+ return make_loaders(args)
+
+ def set_defaults(self, **kwargs):
+ for k, v in kwargs.items():
+ self.defaults[k] = v
+
+ def apply_defaults(self, args):
+ for k, v in self.defaults.items():
+ k = k.replace('-', '_')
+ if not hasattr(args, k):
+ setattr(args, k, v)
+
+
+def make_data_loader(dataset, batch_size, args):
+
+ shuffle = args.shuffle
+ if shuffle:
+ sampler = data_utils.samplers.RandomSampler(
+ dataset, replacement=True, num_samples=batch_size * args.train_iters)
+ else:
+ sampler = torch.utils.data.SequentialSampler(dataset)
+ world_size = torch.distributed.get_world_size(
+ group=mpu.get_data_parallel_group())
+ rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
+ distributed = world_size > 1
+ drop_last = distributed
+
+ if distributed:
+ batch_sampler = data_utils.samplers.DistributedBatchSampler(sampler,
+ batch_size,
+ drop_last,
+ rank,
+ world_size)
+ else:
+ batch_sampler = torch.utils.data.BatchSampler(sampler,
+ batch_size,
+ drop_last)
+
+ data_loader = torch.utils.data.DataLoader(dataset,
+ batch_sampler=batch_sampler,
+ num_workers=args.num_workers,
+ pin_memory=True)
+
+ return data_loader
+
+
+def make_tfrecord_loaders(args):
+ """Load train/val/test dataset from shuffled TFRecords"""
+
+ import data_utils.tf_dl
+ data_set_args = {'batch_size': args.batch_size,
+ 'max_seq_len': args.seq_length,
+ 'max_preds_per_seq': args.max_preds_per_seq,
+ 'train': True,
+ 'num_workers': max(args.num_workers, 1),
+ 'seed': args.seed + args.rank + 1,
+ 'threaded_dl': args.num_workers > 0
+ }
+ train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
+ **data_set_args)
+ data_set_args['train'] = False
+ if args.eval_seq_length is not None:
+ data_set_args['max_seq_len'] = args.eval_seq_length
+ if args.eval_max_preds_per_seq is not None:
+ data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
+ valid = None
+ if args.valid_data is not None:
+ valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data,
+ **data_set_args)
+ test = None
+ if args.test_data is not None:
+ test = data_utils.tf_dl.TFRecordDataLoader(args.test_data,
+ **data_set_args)
+ tokenizer = data_utils.make_tokenizer(args.tokenizer_type,
+ train,
+ args.tokenizer_path,
+ args.vocab_size,
+ args.tokenizer_model_type,
+ cache_dir=args.cache_dir)
+
+ return (train, valid, test), tokenizer
+
+
+def make_loaders(args):
+ """makes training/val/test"""
+
+ if args.data_loader == 'tfrecords':
+ return make_tfrecord_loaders(args)
+ world_size = torch.distributed.get_world_size(
+ group=mpu.get_data_parallel_group())
+ batch_size = args.batch_size * world_size
+ eval_batch_size = batch_size
+ if args.eval_batch_size is not None:
+ eval_batch_size = args.eval_batch_size * world_size
+ seq_length = args.seq_length
+ if seq_length < 0:
+ seq_length = seq_length * world_size
+ eval_seq_length = args.eval_seq_length
+ if eval_seq_length is not None and eval_seq_length < 0:
+ eval_seq_length = eval_seq_length * world_size
+ split = get_split(args)
+ if args.data_path is not None:
+ args.train_data = args.data_path
+ data_set_args = {
+ 'path': args.train_data,
+ 'seq_length': seq_length,
+ 'lazy': args.data_loader == 'lazy',
+ 'delim': args.delim,
+ 'text_key': args.text_key,
+ 'label_key': 'label',
+ 'non_binary_cols': None,
+ 'ds_type': args.data_set_type,
+ 'split': split,
+ 'loose': args.loose_json,
+ 'tokenizer_type': args.tokenizer_type,
+ 'tokenizer_model_path': args.tokenizer_path,
+ 'vocab_size': args.vocab_size,
+ 'model_type': args.tokenizer_model_type,
+ 'cache_dir': args.cache_dir,
+ 'max_preds_per_seq': args.max_preds_per_seq,
+ 'presplit_sentences': args.presplit_sentences,
+ 'parallel_group': mpu.get_data_parallel_group()}
+
+ eval_set_args = copy.copy(data_set_args)
+ eval_set_args['split'] = [1.]
+ # if optional eval args were set then replace their
+ # equivalent values in the arg dict
+ if eval_seq_length:
+ eval_set_args['seq_length'] = eval_seq_length
+ if args.eval_max_preds_per_seq:
+ eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
+ if args.eval_text_key is not None:
+ eval_set_args['text_key'] = args.eval_text_key
+
+ # make datasets splits and tokenizer
+ train = None
+ valid = None
+ test = None
+
+ if args.train_data is not None:
+ train, tokenizer = data_utils.make_dataset(**data_set_args)
+ if data_utils.should_split(split):
+ train, valid, test = train
+ eval_set_args['tokenizer'] = tokenizer
+
+ # make training and val dataset if necessary
+ if valid is None and args.valid_data is not None:
+ eval_set_args['path'] = args.valid_data
+ valid, tokenizer = data_utils.make_dataset(**eval_set_args)
+ eval_set_args['tokenizer'] = tokenizer
+ if test is None and args.test_data is not None:
+ eval_set_args['path'] = args.test_data
+ test, tokenizer = data_utils.make_dataset(**eval_set_args)
+
+ # wrap datasets with data loader
+ if train is not None and args.batch_size > 0:
+ train = make_data_loader(train, batch_size, args)
+ args.do_train = True
+ else:
+ args.do_train = False
+ eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
+ if valid is not None:
+ valid = make_data_loader(valid, eval_batch_size, args)
+ args.do_valid = True
+ else:
+ args.do_valid = False
+ if test is not None:
+ test = make_data_loader(test, eval_batch_size, args)
+ args.do_test = True
+ else:
+ args.do_test = False
+
+ return (train, valid, test), tokenizer
+
+
+def get_split(args):
+ """
+    Get train/val/test dataset split proportions from a comma- or slash-separated string
+ """
+ splits = []
+ if args.split.find(',') != -1:
+ splits = [float(s) for s in args.split.split(',')]
+ elif args.split.find('/') != -1:
+ splits = [float(s) for s in args.split.split('/')]
+ else:
+ splits = [float(args.split)]
+ split_total = sum(splits)
+ if split_total < 1.:
+ splits.append(1 - split_total)
+ while len(splits) < 3:
+ splits.append(0.)
+ splits = splits[:3]
+ if args.valid_data is not None:
+ splits[1] = 0.
+ if args.test_data is not None:
+ splits[2] = 0.
+ final_sum = sum(splits)
+ return [s / final_sum for s in splits]
+
+
+def configure_data():
+ """add cmdline flags for configuring datasets"""
+ # These are options that are used by data_utils, but are either
+ # deprecated or not meant to be exposed to the command line user.
+    # These options are intended to be set in code by specific scripts.
+ defaults = {
+ 'world_size': 1,
+ 'rank': -1,
+ 'persist_state': 0,
+ 'lazy': False,
+ 'transpose': False,
+ 'data_set_type': 'supervised',
+ 'seq_length': 256,
+ 'eval_seq_length': 256,
+ 'samples_per_shard': 100
+ }
+
+ return DataConfig(defaults=defaults)
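get_split() turns the --split string into three proportions: comma- and slash-separated forms are accepted, a remainder is appended when the values sum to less than one, the list is padded or truncated to three entries, explicitly supplied valid/test files zero out their proportions, and the result is renormalized. A worked example, assuming the deprecated module imports cleanly; Namespace just stands in for the parsed args.

from argparse import Namespace
from megatron.deprecated_data_utils.configure_data import get_split

# "90,5,5" and "90/5/5" parse identically and are normalized to sum to 1.
print(get_split(Namespace(split='90,5,5', valid_data=None, test_data=None)))
# -> [0.9, 0.05, 0.05]

# A single value below 1 gets the remainder appended, then padding to three entries.
print(get_split(Namespace(split='0.8', valid_data=None, test_data=None)))
# -> [0.8, 0.2, 0.0]

# An explicit validation file zeroes the val proportion before renormalizing.
print(get_split(Namespace(split='90,5,5', valid_data='val.json', test_data=None)))
# -> approximately [0.947, 0.0, 0.053]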
diff --git a/megatron_lm/megatron/deprecated_data_utils/corpora.py b/megatron_lm/megatron/deprecated_data_utils/corpora.py
new file mode 100755
index 0000000..73749d9
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/corpora.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""several datasets with preset arguments"""
+from .datasets import json_dataset, csv_dataset
+import os
+
+
+class wikipedia(json_dataset):
+ """
+ dataset for wikipedia with arguments configured for convenience
+
+ command line usage: `--train-data wikipedia`
+ """
+ PATH = 'data/wikipedia/wikidump_lines.json'
+ assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
+
+ def __init__(self, **kwargs):
+ assert os.path.exists(wikipedia.PATH), \
+ wikipedia.assert_str
+ if not kwargs:
+ kwargs = {}
+ kwargs['text_key'] = 'text'
+ kwargs['loose_json'] = True
+ super(wikipedia, self).__init__(wikipedia.PATH, **kwargs)
+
+
+class webtext(json_dataset):
+ """
+ dataset for webtext with arguments configured for convenience
+
+ command line usage: `--train-data webtext`
+ """
+ PATH = 'data/webtext/data.json'
+ assert_str = "make sure to set PATH for webtext data_utils/corpora.py"
+
+ def __init__(self, **kwargs):
+ assert os.path.exists(webtext.PATH), \
+ webtext.assert_str
+ if not kwargs:
+ kwargs = {}
+ kwargs['text_key'] = 'text'
+ kwargs['loose_json'] = True
+ super(webtext, self).__init__(webtext.PATH, **kwargs)
+
+
+NAMED_CORPORA = {
+ 'wikipedia': wikipedia,
+ 'webtext': webtext,
+}
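NAMED_CORPORA is what --train-data names resolve against in get_dataset()/supported_corpus() above, so registering a new preset corpus is a matter of adding another json_dataset subclass. A sketch of a hypothetical entry, meant to be appended to this file where os, json_dataset and NAMED_CORPORA are already in scope (the name and path are made up):

class mycorpus(json_dataset):
    """
    dataset for a hypothetical local corpus with arguments configured for convenience

    command line usage: `--train-data mycorpus`
    """
    PATH = 'data/mycorpus/data.json'
    assert_str = "make sure to set PATH for mycorpus in data_utils/corpora.py"

    def __init__(self, **kwargs):
        assert os.path.exists(mycorpus.PATH), mycorpus.assert_str
        if not kwargs:
            kwargs = {}
        kwargs['text_key'] = 'text'
        kwargs['loose_json'] = True
        super(mycorpus, self).__init__(mycorpus.PATH, **kwargs)


NAMED_CORPORA['mycorpus'] = mycorpus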
diff --git a/megatron_lm/megatron/deprecated_data_utils/datasets.py b/megatron_lm/megatron/deprecated_data_utils/datasets.py
new file mode 100755
index 0000000..bf8ef8a
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/datasets.py
@@ -0,0 +1,883 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""dataset objects for jsons, csvs, and BERT datasets"""
+
+import os
+import time
+from operator import itemgetter
+from bisect import bisect_right
+import json
+import csv
+import math
+import random
+import warnings
+from itertools import accumulate
+
+from torch.utils import data
+import pandas as pd
+import numpy as np
+
+import nltk
+from nltk import tokenize
+
+from .lazy_loader import lazy_array_loader, exists_lazy, make_lazy
+from .tokenization import Tokenization
+
+
+class ConcatDataset(data.Dataset):
+ """
+ Dataset to concatenate multiple datasets.
+ Purpose: useful to assemble different existing datasets, possibly
+ large-scale datasets as the concatenation operation is done in an
+ on-the-fly manner.
+ Arguments:
+ datasets (sequence): List of datasets to be concatenated.
+ """
+
+ @staticmethod
+ def cumsum(sequence):
+ r, s = [], 0
+ for e in sequence:
+ l = len(e)
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, datasets, **kwargs):
+ super(ConcatDataset, self).__init__()
+ assert len(datasets) > 0, 'datasets should not be an empty iterable'
+ self.datasets = list(datasets)
+ self.is_lazy = sum([isinstance(ds, lazy_array_loader)
+ for ds in self.datasets]) == len(self.datasets)
+ self.cumulative_sizes = self.cumsum(self.datasets)
+ self._X = None
+ self._Y = None
+ self._lens = None
+
+ def SetTokenizer(self, tokenizer):
+ for ds in self.datasets:
+ ds.SetTokenizer(tokenizer)
+
+ def GetTokenizer(self):
+ return self.datasets[0].GetTokenizer()
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx = bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx]
+
+ @property
+ def lens(self):
+ if self._lens is None:
+ self._lens = []
+ if self.is_lazy:
+ for data in self.datasets:
+ self._lens.extend(data.lens)
+ else:
+ for data in self.datasets:
+ self._lens.extend([len(d['text']) if isinstance(
+ d, dict) else len(d) for d in data])
+ return self._lens
+
+ @property
+ def X(self):
+ if self._X is None:
+ self._X = []
+ for data in self.datasets:
+ self._X.extend(data.X)
+ return self._X
+
+ @property
+ def Y(self):
+ if self._Y is None:
+ self._Y = []
+ for data in self.datasets:
+ self._Y.extend(list(data.Y))
+ self._Y = np.array(self._Y)
+ return self._Y
+
+ @property
+ def cummulative_sizes(self):
+ warnings.warn("cummulative_sizes attribute is renamed to "
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
+ return self.cumulative_sizes
+
+
+class SplitDataset(data.Dataset):
+ """
+ Dataset wrapper to access a subset of another dataset.
+ Purpose: useful to index into existing datasets, possibly
+ large-scale datasets as the subindexing operation is done in an
+ on-the-fly manner.
+ Arguments:
+ ds (Dataset or array-like): List of datasets to be subindexed
+ split_inds (1D array-like): List of indices part of subset
+ """
+
+ def __init__(self, ds, split_inds, **kwargs):
+ self.split_inds = list(split_inds)
+ self.wrapped_data = ds
+ self.is_lazy = isinstance(ds, lazy_array_loader) or (hasattr(ds, 'is_lazy') and ds.is_lazy)
+ if self.is_lazy:
+ self.lens = itemgetter(*self.split_inds)(list(self.wrapped_data.lens))
+ self._X = None
+ self._Y = None
+
+ def __len__(self):
+ return len(self.split_inds)
+
+ def __getitem__(self, index):
+ return self.wrapped_data[self.split_inds[index]]
+
+ def SetTokenizer(self, tokenizer):
+ self.wrapped_data.SetTokenizer(tokenizer)
+
+ def GetTokenizer(self):
+ return self.wrapped_data.GetTokenizer()
+
+ @property
+ def X(self):
+ if self._X is None:
+ self._X = itemgetter(*self.split_inds)(self.wrapped_data.X)
+ return self._X
+
+ @property
+ def Y(self):
+ if self._Y is None:
+ self._Y = np.array(itemgetter(*self.split_inds)(self.wrapped_data.Y))
+ return self._Y
+
+ def __iter__(self):
+ for idx in self.split_inds:
+ yield self.wrapped_data[idx]
+
+
+def split_ds(ds, split=[.8, .2, .0], shuffle=True):
+ """
+ Split a dataset into subsets given proportions of how
+ much to allocate per split. If a split is 0% returns None for that split.
+ Purpose: Useful for creating train/val/test splits
+ Arguments:
+ ds (Dataset or array-like): Data to be split.
+ split (1D array-like): proportions to split `ds`. `sum(splits) != 0`
+ shuffle (boolean): Randomly split dataset. Default: True
+ """
+ split_sum = sum(split)
+ if split_sum == 0:
+ raise Exception('Split cannot sum to 0.')
+ split = np.array(split)
+ split /= split_sum
+ ds_len = len(ds)
+ inds = np.arange(ds_len)
+ if shuffle:
+ np.random.shuffle(inds)
+ start_idx = 0
+ residual_idx = 0
+ rtn_ds = [None] * len(split)
+ for i, f in enumerate(split):
+ if f != 0:
+ proportion = ds_len * split[i]
+ residual_idx += proportion % 1
+ split_ = int(int(proportion) + residual_idx)
+ split_inds = inds[start_idx:start_idx + max(split_, 1)]
+ rtn_ds[i] = SplitDataset(ds, split_inds)
+ start_idx += split_
+ residual_idx %= 1
+ return rtn_ds
+
+
+class csv_dataset(data.Dataset):
+ """
+ Class for loading datasets from csv files.
+ Purpose: Useful for loading data for unsupervised modeling or transfer tasks
+ Arguments:
+ path (str): Path to csv file with dataset.
+ tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None
+ preprocess_fn (callable): Callable that process a string into desired format.
+ delim (str): delimiter for csv. Default: ','
+ binarize_sent (bool): binarize label values to 0 or 1 if they\'re on a different scale. Default: False
+ drop_unlabeled (bool): drop rows with unlabelled values. Always fills remaining empty
+ columns with -1 (regardless if rows are dropped based on value) Default: False
+ text_key (str): key to get text from csv. Default: 'sentence'
+ label_key (str): key to get label from json dictionary. Default: 'label'
+ Attributes:
+ X (list): all strings from the csv file
+ Y (np.ndarray): labels to train with
+ """
+
+ def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',',
+ binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
+ **kwargs):
+ self.is_lazy = False
+ self.preprocess_fn = preprocess_fn
+ self.SetTokenizer(tokenizer)
+ self.path = path
+ self.delim = delim
+ self.text_key = text_key
+ self.label_key = label_key
+ self.drop_unlabeled = drop_unlabeled
+
+ if '.tsv' in self.path:
+ self.delim = '\t'
+
+ self.X = []
+ self.Y = []
+ try:
+ cols = [text_key]
+ if isinstance(label_key, list):
+ cols += label_key
+ else:
+ cols += [label_key]
+ data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1')
+ except BaseException:
+ data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1')
+
+ data = data.dropna(axis=0)
+
+ self.X = data[text_key].values.tolist()
+ try:
+ self.Y = data[label_key].values
+ except Exception as e:
+ self.Y = np.ones(len(self.X)) * -1
+
+ if binarize_sent:
+ self.Y = binarize_labels(self.Y, hard=binarize_sent)
+
+ def SetTokenizer(self, tokenizer):
+ if tokenizer is None:
+ self.using_tokenizer = False
+ if not hasattr(self, '_tokenizer'):
+ self._tokenizer = tokenizer
+ else:
+ self.using_tokenizer = True
+ self._tokenizer = tokenizer
+
+ def GetTokenizer(self):
+ return self._tokenizer
+
+ @property
+ def tokenizer(self):
+ if self.using_tokenizer:
+ return self._tokenizer
+ return None
+
+ def __len__(self):
+ return len(self.X)
+
+ def __getitem__(self, index):
+ """process+tokenize string and return string,label,and stringlen"""
+ x = self.X[index]
+ if self.tokenizer is not None:
+ x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn)
+ elif self.preprocess_fn is not None:
+ x = self.preprocess_fn(x)
+ y = self.Y[index]
+ if isinstance(y, str):
+ if self.tokenizer is not None:
+ y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn)
+ elif self.preprocess_fn is not None:
+ y = self.preprocess_fn(y)
+ return {'text': x, 'length': len(x), 'label': y}
+
+ def write(self, writer_gen=None, path=None, skip_header=False):
+ """
+ given a generator of metrics for each of the data points X_i,
+ write the metrics, text, and labels to a csv file
+ """
+ if path is None:
+ path = self.path + '.results'
+ print('generating csv at ' + path)
+ with open(path, 'w') as csvfile:
+ c = csv.writer(csvfile, delimiter=self.delim)
+ if writer_gen is not None:
+ # if first item of generator is a header of what the metrics mean then
+ # write header to csv file
+ if not skip_header:
+ header = (self.label_key,) + tuple(next(writer_gen)) + (self.text_key,)
+ c.writerow(header)
+ for i, row in enumerate(writer_gen):
+ row = (self.Y[i],) + tuple(row) + (self.X[i],)
+ c.writerow(row)
+ else:
+ c.writerow([self.label_key, self.text_key])
+ for row in zip(self.Y, self.X):
+ c.writerow(row)
+
+
+class json_dataset(data.Dataset):
+ """
+ Class for loading datasets from a json dump.
+ Purpose: Useful for loading data for unsupervised modeling or transfer tasks
+ Arguments:
+ path (str): path to json file with dataset.
+ tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None
+ preprocess_fn (callable): callable function that process a string into desired format.
+ Takes string, maxlen=None, encode=None as arguments. Default: process_str
+ text_key (str): key to get text from json dictionary. Default: 'sentence'
+ label_key (str): key to get label from json dictionary. Default: 'label'
+ Attributes:
+ all_strs (list): list of all strings from the dataset
+ all_labels (list): list of all labels from the dataset (if they have it)
+ """
+
+ def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
+ text_key='sentence', label_key='label', loose_json=False, **kwargs):
+ self.is_lazy = False
+ self.preprocess_fn = preprocess_fn
+ self.path = path
+ self.SetTokenizer(tokenizer)
+ self.X = []
+ self.Y = []
+ self.text_key = text_key
+ self.label_key = label_key
+ self.loose_json = loose_json
+
+ for j in self.load_json_stream(self.path):
+ s = j[text_key]
+ self.X.append(s)
+ self.Y.append(j[label_key])
+
+ if binarize_sent:
+ self.Y = binarize_labels(self.Y, hard=binarize_sent)
+
+ def SetTokenizer(self, tokenizer):
+ if tokenizer is None:
+ self.using_tokenizer = False
+ if not hasattr(self, '_tokenizer'):
+ self._tokenizer = tokenizer
+ else:
+ self.using_tokenizer = True
+ self._tokenizer = tokenizer
+
+ def GetTokenizer(self):
+ return self._tokenizer
+
+ @property
+ def tokenizer(self):
+ if self.using_tokenizer:
+ return self._tokenizer
+ return None
+
+ def __getitem__(self, index):
+ """gets the index'th string from the dataset"""
+ x = self.X[index]
+ if self.tokenizer is not None:
+ x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn)
+ elif self.preprocess_fn is not None:
+ x = self.preprocess_fn(x)
+ y = self.Y[index]
+ if isinstance(y, str):
+ if self.tokenizer is not None:
+ y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn)
+ elif self.preprocess_fn is not None:
+ y = self.preprocess_fn(y)
+ return {'text': x, 'length': len(x), 'label': y}
+
+ def __len__(self):
+ return len(self.X)
+
+ def write(self, writer_gen=None, path=None, skip_header=False):
+ """
+ given a generator of metrics for each of the data points X_i,
+ write the metrics, text, and labels to a json file
+ """
+ if path is None:
+ path = self.path + '.results'
+
+ jsons = []
+
+ if writer_gen is not None:
+            # if the first item of the generator is a header naming the metrics,
+            # record those keys for the json output
+ def gen_helper():
+ keys = {}
+ keys[0] = self.label_key
+ if not skip_header:
+ for idx, k in enumerate(tuple(next(writer_gen))):
+ keys[idx + 1] = k
+ for i, row in enumerate(writer_gen):
+ if i == 0 and skip_header:
+ for idx, _ in enumerate(row):
+ keys[idx + 1] = 'metric_%d' % (idx,)
+ j = {}
+ for idx, v in enumerate((self.Y[i],) + tuple(row)):
+ k = keys[idx]
+ j[k] = v
+ yield j
+ else:
+ def gen_helper():
+ for y in self.Y:
+ j = {}
+ j[self.label_key] = y
+ yield j
+
+ def out_stream():
+ for i, j in enumerate(gen_helper()):
+ j[self.text_key] = self.X[i]
+ yield j
+
+ self.save_json_stream(path, out_stream())
+
+ def save_json_stream(self, save_path, json_stream):
+ if self.loose_json:
+ with open(save_path, 'w') as f:
+ for i, j in enumerate(json_stream):
+ write_string = ''
+ if i != 0:
+ write_string = '\n'
+ write_string += json.dumps(j)
+ f.write(write_string)
+ else:
+ jsons = [j for j in json_stream]
+ json.dump(jsons, open(save_path, 'w'), separators=(',', ':'))
+
+ def load_json_stream(self, load_path):
+ if not self.loose_json:
+ jsons = json.load(open(load_path, 'r'))
+ generator = iter(jsons)
+ else:
+ def gen_helper():
+ with open(load_path, 'r') as f:
+ for row in f:
+ yield json.loads(row)
+ generator = gen_helper()
+
+ for j in generator:
+ if self.label_key not in j:
+ j[self.label_key] = -1
+ yield j
+
+
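+
+# Illustrative usage sketch for json_dataset, assuming a loose-json file in which each
+# line is an object with 'sentence' and 'label' fields; 'reviews.json' is a placeholder
+# path, not a file shipped with this repository.
+def _example_json_dataset_usage():
+    ds = json_dataset('reviews.json', text_key='sentence', label_key='label',
+                      loose_json=True)
+    sample = ds[0]     # {'text': <str>, 'length': <int>, 'label': <label or -1>}
+    ds.write()         # writes labels + text back out as 'reviews.json.results'
+    return sample
+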
+class GPT2Dataset(data.Dataset):
+
+ def __init__(self, ds,
+ max_seq_len=1024,
+ num_samples=None,
+ weighted=True,
+ sample_across_doc=True,
+ random_across_doc_sampling=True,
+ bias_for_single_doc=False,
+ sentence_start=False, **kwargs):
+ self.ds = ds
+ self.ds_len = len(self.ds)
+ self.num_samples = num_samples
+ if num_samples is None:
+ self.num_samples = 1000 * self.ds_len
+ self.max_seq_len = max_seq_len
+ self.tokenizer = self.ds.GetTokenizer()
+ self.ds.SetTokenizer(None)
+ self.weighted = weighted
+ self.sample_across_doc = sample_across_doc
+ self.random_across_doc_sampling = random_across_doc_sampling
+ self.bias_for_single_doc = bias_for_single_doc
+ self.sentence_start = sentence_start
+ self.init_weighting()
+
+ def init_weighting(self):
+ if self.weighted:
+ if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
+ lens = np.array(self.ds.lens)
+ else:
+ lens = np.array([len(d['text']) if isinstance(d, dict)
+ else len(d) for d in self.ds])
+ self.total_len = np.sum(lens)
+ self.weighting = list(accumulate(lens))
+ else:
+ self.weighting = None
+
+ def get_weighted_samples(self, np_rng):
+ if self.weighting is not None:
+ idx = np_rng.randint(self.total_len)
+ return bisect_right(self.weighting, idx)
+ else:
+ return np_rng.randint(self.ds_len)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ # init rng
+ rng = random.Random(idx)
+ rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
+
+ # get possibly weighted random index from dataset
+ data_idx = self.get_weighted_samples(rng)
+        # data_idx = rng.choice(self.ds_len, p=self.weighting)
+ tokens = self.getidx(data_idx)
+
+ # truncate or pad tokens
+ num_tokens = len(tokens)
+ if self.bias_for_single_doc:
+ tokens_to_strip = num_tokens - self.max_seq_len - 1
+ else:
+ tokens_to_strip = num_tokens - 1
+ if tokens_to_strip > 0:
+ strip_left_tokens = rng.randint(tokens_to_strip + 1)
+ tokens = tokens[strip_left_tokens:]
+ if self.sentence_start:
+ token_copy = list(tokens)
+ not_done = True
+ while (len(token_copy) > 0) and not_done:
+ tok = token_copy.pop(0)
+ if self.contains_sentence_end(tok):
+ tokens = token_copy
+ not_done = False
+            strip_right_tokens = len(tokens) - self.max_seq_len - 1
+            if strip_right_tokens > 0:
+                tokens = tokens[:-strip_right_tokens]
+
+ if self.sample_across_doc:
+ while (len(tokens) < (self.max_seq_len + 1)):
+ if self.random_across_doc_sampling:
+ data_idx = self.get_weighted_samples(rng)
+ else:
+ data_idx = (data_idx + 1) % self.ds_len
+ tokens += self.getidx(data_idx)
+ tokens = tokens[:(self.max_seq_len + 1)]
+
+ tokens = self.pad_seq(tokens)
+ return {'text': np.array(tokens), }
+
+ def getidx(self, data_idx):
+ data = self.ds[data_idx]
+ if isinstance(data, dict):
+ data = data['text']
+ # tokenize
+ tokenization = self.tokenizer.EncodeAsIds(data)
+ tokenization.append(self.tokenizer.get_command('eos'))
+ tokens = tokenization.tokenization
+ return tokens
+
+ def pad_seq(self, seq):
+ total_tokens = self.max_seq_len + 1
+ num_pad_tokens = max(0, total_tokens - len(seq))
+ seq += [self.tokenizer.get_command('pad').Id] * (num_pad_tokens)
+ return seq
+
+ def contains_sentence_end(self, tok):
+ tok = self.tokenizer.IdToToken(tok)
+ if '.' in tok:
+ return True
+ if '?' in tok:
+ return True
+ if '!' in tok:
+ return True
+ return False
+
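+
+# Illustrative sketch of the length-weighted sampling in GPT2Dataset.get_weighted_samples
+# above: a document is picked with probability proportional to its length by drawing a
+# position in [0, total_len) and bisecting the cumulative-length array. The toy lengths
+# below are arbitrary.
+def _example_weighted_doc_sampling():
+    from bisect import bisect_right
+    from itertools import accumulate
+    import numpy as np
+    toy_lens = [10, 30, 60]
+    cumulative = list(accumulate(toy_lens))     # [10, 40, 100]
+    np_rng = np.random.RandomState(0)
+    position = np_rng.randint(cumulative[-1])   # uniform position over all tokens
+    return bisect_right(cumulative, position)   # index of the chosen document
+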
+
+class bert_sentencepair_dataset(data.Dataset):
+ """
+ Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair.
+ Arguments:
+ ds (Dataset or array-like): data corpus to use for training
+ max_seq_len (int): maximum sequence length to use for a sentence pair
+ mask_lm_prob (float): proportion of tokens to mask for masked LM
+ max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10
+ short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len
+ dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
+
+ """
+
+ def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None,
+ short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
+ self.ds = ds
+ self.ds_len = len(self.ds)
+ self.tokenizer = self.ds.GetTokenizer()
+ self.vocab_words = list(self.tokenizer.text_token_vocab.values())
+ self.ds.SetTokenizer(None)
+ self.max_seq_len = max_seq_len
+ self.mask_lm_prob = mask_lm_prob
+ if max_preds_per_seq is None:
+ max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10
+ self.max_preds_per_seq = max_preds_per_seq
+ self.short_seq_prob = short_seq_prob
+ self.dataset_size = dataset_size
+ if self.dataset_size is None:
+ self.dataset_size = self.ds_len * (self.ds_len - 1)
+ self.presplit_sentences = presplit_sentences
+ if not self.presplit_sentences:
+ nltk.download('punkt', download_dir="./nltk")
+ self.weighted = weighted
+ self.get_weighting()
+
+ def get_weighting(self):
+ if self.weighted:
+ if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
+ lens = np.array(self.ds.lens)
+ else:
+ lens = np.array([len(d['text']) if isinstance(d, dict) else len(d)
+ for d in self.ds])
+ self.total_len = np.sum(lens)
+ self.weighting = list(accumulate(lens))
+ else:
+ self.weighting = None
+
+ def get_weighted_samples(self, np_rng):
+ if self.weighting is not None:
+ idx = np_rng.randint(self.total_len)
+ return bisect_right(self.weighting, idx)
+ else:
+ return np_rng.randint(self.ds_len)
+
+ def __len__(self):
+ return self.dataset_size
+
+ def __getitem__(self, idx):
+ # get rng state corresponding to index (allows deterministic random pair)
+ rng = random.Random(idx)
+ np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
+ # get seq length
+ target_seq_length = self.max_seq_len
+ short_seq = False
+ if rng.random() < self.short_seq_prob:
+ target_seq_length = rng.randint(2, target_seq_length)
+ short_seq = True
+
+ # get sentence pair and label
+ is_random_next = None
+ lena = 0
+ lenb = 0
+ while (is_random_next is None) or (lena < 1) or (lenb < 1):
+ tokensa, tokensb, is_random_next = self.create_random_sentencepair(
+ target_seq_length, rng, np_rng)
+ lena = len(tokensa[0])
+ lenb = len(tokensb[0])
+
+ # truncate sentence pair to max_seq_len
+ tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng)
+ # join sentence pair, mask, and pad
+ tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(
+ tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
+        sample = {'text': np.array(tokens[0]),
+                  'types': np.array(tokens[1]),
+                  'is_random': int(is_random_next),
+                  'mask': np.array(mask),
+                  'mask_labels': np.array(mask_labels),
+                  'pad_mask': np.array(pad_mask)}
+ return sample
+
+ def sentence_split(self, document):
+ """split document into sentences"""
+ lines = document.split('\n')
+ if self.presplit_sentences:
+ return [line for line in lines if line]
+ rtn = []
+ for line in lines:
+ if line != '':
+ rtn.extend(tokenize.sent_tokenize(line))
+ return rtn
+
+ def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
+ """tokenize sentence and get token types"""
+ tokens = self.tokenizer.EncodeAsIds(sent).tokenization
+ str_type = 'str' + str(sentence_num)
+ token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
+ return tokens, token_types
+
+ def get_doc(self, idx):
+ """gets text of document corresponding to idx"""
+ rtn = self.ds[idx]
+ if isinstance(rtn, dict):
+ rtn = rtn['text']
+ return rtn
+
+ def create_random_sentencepair(self, target_seq_length, rng, np_rng):
+ """
+ fetches a random sentencepair corresponding to rng state similar to
+ https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L248-L294
+ """
+ is_random_next = None
+
+ curr_strs = []
+ curr_str_types = []
+ curr_len = 0
+
+ while curr_len < 1:
+ curr_len = 0
+ doc_a = None
+ while doc_a is None:
+ if self.weighted:
+ # doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting)
+ doc_a_idx = self.get_weighted_samples(np_rng)
+ else:
+ doc_a_idx = rng.randint(0, self.ds_len - 1)
+ doc_a = self.sentence_split(self.get_doc(doc_a_idx))
+ if not doc_a:
+ doc_a = None
+
+ random_start_a = rng.randint(0, len(doc_a) - 1)
+ while random_start_a < len(doc_a):
+ sentence = doc_a[random_start_a]
+ sentence, sentence_types = self.sentence_tokenize(
+ sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
+ curr_strs.append(sentence)
+ curr_str_types.append(sentence_types)
+ curr_len += len(sentence)
+ if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length:
+ break
+ random_start_a = (random_start_a + 1)
+
+ if curr_strs:
+ num_a = 1
+ if len(curr_strs) >= 2:
+ num_a = rng.randint(0, len(curr_strs))
+
+ tokens_a = []
+ token_types_a = []
+ for j in range(num_a):
+ tokens_a.extend(curr_strs[j])
+ token_types_a.extend(curr_str_types[j])
+
+ tokens_b = []
+ token_types_b = []
+ is_random_next = False
+ if len(curr_strs) == 1 or rng.random() < 0.5:
+ is_random_next = True
+ target_b_length = target_seq_length - len(tokens_a)
+ b_len = 0
+ while b_len < 1:
+ doc_b = None
+ while doc_b is None:
+ doc_b_idx = rng.randint(0, self.ds_len - 2)
+ doc_b_idx += int(doc_b_idx >= doc_a_idx)
+
+ doc_b = self.sentence_split(self.get_doc(doc_b_idx))
+ if not doc_b:
+ doc_b = None
+
+ random_start_b = rng.randint(0, len(doc_b) - 1)
+ while random_start_b < len(doc_b):
+ sentence_b = doc_b[random_start_b]
+ new_b_tokens, new_b_types = self.sentence_tokenize(
+ sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
+ b_len += len(new_b_tokens)
+ tokens_b.extend(new_b_tokens)
+ token_types_b.extend(new_b_types)
+ if len(tokens_b) >= target_b_length:
+ break
+ random_start_b = (random_start_b + 1)
+ else:
+ is_random_next = False
+ for j in range(num_a, len(curr_strs)):
+ tokens_b.extend(curr_strs[j])
+ token_types_b.extend(curr_str_types[j])
+
+ return (tokens_a, token_types_a), (tokens_b, token_types_b), is_random_next
+
+ def truncate_seq_pair(self, a, b, max_seq_len, rng):
+ """
+ Truncate sequence pair according to original BERT implementation:
+ https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391
+ """
+ tokens_a, token_types_a = a
+ tokens_b, token_types_b = b
+ max_num_tokens = self.calc_seq_len(max_seq_len)
+ # max_num_tokens = max_seq_len - 3
+ while True:
+ len_a = len(tokens_a)
+ len_b = len(tokens_b)
+ total_length = len_a + len_b
+ if total_length <= max_num_tokens:
+ break
+ if len(tokens_a) > len(tokens_b):
+ trunc_tokens = tokens_a
+ trunc_types = token_types_a
+ else:
+ trunc_tokens = tokens_b
+ trunc_types = token_types_b
+
+ assert len(trunc_tokens) >= 1
+
+ if rng.random() < 0.5:
+ trunc_tokens.pop(0)
+ trunc_types.pop(0)
+ else:
+ trunc_tokens.pop()
+ trunc_types.pop()
+ return (tokens_a, token_types_a), (tokens_b, token_types_b)
+
+ def calc_seq_len(self, max_seq_len):
+ return max_seq_len - 3
+
+ def mask_token(self, idx, tokens, types, vocab_words, rng):
+ """
+ helper function to mask `idx` token from `tokens` according to
+ section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
+ """
+ label = tokens[idx]
+ if rng.random() < 0.8:
+ new_label = self.tokenizer.get_command('MASK').Id
+ else:
+ if rng.random() < 0.5:
+ new_label = label
+ else:
+ new_label = rng.choice(vocab_words)
+
+ tokens[idx] = new_label
+
+ return label
+
+ def pad_seq(self, seq):
+ """helper function to pad sequence pair"""
+ num_pad = max(0, self.max_seq_len - len(seq))
+ pad_mask = [0] * len(seq) + [1] * num_pad
+ seq += [self.tokenizer.get_command('pad').Id] * num_pad
+ return seq, pad_mask
+
+ def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
+        tokens = ([self.tokenizer.get_command('ENC').Id] + tokens_a
+                  + [self.tokenizer.get_command('sep').Id] + tokens_b
+                  + [self.tokenizer.get_command('sep').Id])
+        token_types = ([token_types_a[0]] + token_types_a + [token_types_a[0]]
+                       + token_types_b + [token_types_b[0]])
+ return tokens, token_types
+
+ def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
+ """
+ Mask sequence pair for BERT training according to:
+ https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L338
+ """
+ tokens_a, token_types_a = a
+ tokens_b, token_types_b = b
+ tokens, token_types = self.concat_tokens(tokens_a, token_types_a, tokens_b, token_types_b)
+
+ len_a = len(tokens_a)
+ len_b = len(tokens_b)
+
+ cand_indices = [idx + 1 for idx in range(len_a)] + [idx + 2 + len_a for idx in range(len_b)]
+
+ rng.shuffle(cand_indices)
+
+ output_tokens, pad_mask = self.pad_seq(list(tokens))
+ output_types, _ = self.pad_seq(list(token_types))
+
+ num_to_predict = min(max_preds_per_seq, max(1, int(round(len(tokens) * mask_lm_prob))))
+
+ mask = [0] * len(output_tokens)
+ mask_labels = [-1] * len(output_tokens)
+
+ for idx in sorted(cand_indices[:num_to_predict]):
+ mask[idx] = 1
+ label = self.mask_token(idx, output_tokens, output_types, vocab_words, rng)
+ mask_labels[idx] = label
+
+ return (output_tokens, output_types), mask, mask_labels, pad_mask
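+
+
+# Illustrative sketch of the masking rates implemented in bert_sentencepair_dataset.mask_token
+# above: a selected position is replaced by the MASK token 80% of the time, kept unchanged
+# 10% of the time, and swapped for a random vocabulary token 10% of the time (0.2 * 0.5 each).
+def _example_mask_rate_check(trials=100000):
+    rng = random.Random(0)
+    counts = {'mask': 0, 'keep': 0, 'random': 0}
+    for _ in range(trials):
+        if rng.random() < 0.8:
+            counts['mask'] += 1
+        elif rng.random() < 0.5:
+            counts['keep'] += 1
+        else:
+            counts['random'] += 1
+    return {k: v / float(trials) for k, v in counts.items()}   # ~0.8 / ~0.1 / ~0.1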
diff --git a/megatron_lm/megatron/deprecated_data_utils/file_utils.py b/megatron_lm/megatron/deprecated_data_utils/file_utils.py
new file mode 100755
index 0000000..4dc7fdc
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/file_utils.py
@@ -0,0 +1,253 @@
+# This file is provided as is from:
+# https://github.com/huggingface/pytorch-pretrained-BERT
+# Please refer to their repository for copyright.
+
+"""
+Utilities for working with the local dataset cache.
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+Copyright by the AllenNLP authors.
+"""
+from __future__ import (absolute_import, division, print_function, unicode_literals)
+
+import json
+import logging
+import os
+import shutil
+import tempfile
+from functools import wraps
+from hashlib import sha256
+import sys
+from io import open
+
+import boto3
+import requests
+from botocore.exceptions import ClientError
+from tqdm import tqdm
+
+try:
+ from urllib.parse import urlparse
+except ImportError:
+ from urlparse import urlparse
+
+try:
+ from pathlib import Path
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+ Path.home() / '.pytorch_pretrained_bert'))
+except (AttributeError, ImportError):
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+ os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+
+def url_to_filename(url, etag=None):
+ """
+ Convert `url` into a hashed filename in a repeatable way.
+ If `etag` is specified, append its hash to the url's, delimited
+ by a period.
+ """
+ url_bytes = url.encode('utf-8')
+ url_hash = sha256(url_bytes)
+ filename = url_hash.hexdigest()
+
+ if etag:
+ etag_bytes = etag.encode('utf-8')
+ etag_hash = sha256(etag_bytes)
+ filename += '.' + etag_hash.hexdigest()
+
+ return filename
+
+
+def filename_to_url(filename, cache_dir=None):
+ """
+ Return the url and etag (which may be ``None``) stored for `filename`.
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
+ """
+ if cache_dir is None:
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ cache_path = os.path.join(cache_dir, filename)
+ if not os.path.exists(cache_path):
+ raise EnvironmentError("file {} not found".format(cache_path))
+
+ meta_path = cache_path + '.json'
+ if not os.path.exists(meta_path):
+ raise EnvironmentError("file {} not found".format(meta_path))
+
+ with open(meta_path, encoding="utf-8") as meta_file:
+ metadata = json.load(meta_file)
+ url = metadata['url']
+ etag = metadata['etag']
+
+ return url, etag
+
+
+def cached_path(url_or_filename, cache_dir=None):
+ """
+ Given something that might be a URL (or might be a local path),
+ determine which. If it's a URL, download the file and cache it, and
+ return the path to the cached file. If it's already a local path,
+ make sure the file exists and then return the path.
+ """
+ if cache_dir is None:
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
+ url_or_filename = str(url_or_filename)
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ parsed = urlparse(url_or_filename)
+
+ if parsed.scheme in ('http', 'https', 's3'):
+ # URL, so get it from the cache (downloading if necessary)
+ return get_from_cache(url_or_filename, cache_dir)
+ elif os.path.exists(url_or_filename):
+ # File, and it exists.
+ return url_or_filename
+ elif parsed.scheme == '':
+ # File, but it doesn't exist.
+ raise EnvironmentError("file {} not found".format(url_or_filename))
+ else:
+ # Something unknown
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+
+
+def split_s3_path(url):
+ """Split a full s3 path into the bucket name and path."""
+ parsed = urlparse(url)
+ if not parsed.netloc or not parsed.path:
+ raise ValueError("bad s3 path {}".format(url))
+ bucket_name = parsed.netloc
+ s3_path = parsed.path
+ # Remove '/' at beginning of path.
+ if s3_path.startswith("/"):
+ s3_path = s3_path[1:]
+ return bucket_name, s3_path
+
+
+def s3_request(func):
+ """
+ Wrapper function for s3 requests in order to create more helpful error
+ messages.
+ """
+
+ @wraps(func)
+ def wrapper(url, *args, **kwargs):
+ try:
+ return func(url, *args, **kwargs)
+ except ClientError as exc:
+ if int(exc.response["Error"]["Code"]) == 404:
+ raise EnvironmentError("file {} not found".format(url))
+ else:
+ raise
+
+ return wrapper
+
+
+@s3_request
+def s3_etag(url):
+ """Check ETag on S3 object."""
+ s3_resource = boto3.resource("s3")
+ bucket_name, s3_path = split_s3_path(url)
+ s3_object = s3_resource.Object(bucket_name, s3_path)
+ return s3_object.e_tag
+
+
+@s3_request
+def s3_get(url, temp_file):
+ """Pull a file directly from S3."""
+ s3_resource = boto3.resource("s3")
+ bucket_name, s3_path = split_s3_path(url)
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+def http_get(url, temp_file):
+ req = requests.get(url, stream=True)
+ content_length = req.headers.get('Content-Length')
+ total = int(content_length) if content_length is not None else None
+ progress = tqdm(unit="B", total=total)
+ for chunk in req.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ progress.update(len(chunk))
+ temp_file.write(chunk)
+ progress.close()
+
+
+def get_from_cache(url, cache_dir=None):
+ """
+ Given a URL, look for the corresponding dataset in the local cache.
+ If it's not there, download it. Then return the path to the cached file.
+ """
+ if cache_dir is None:
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ if not os.path.exists(cache_dir):
+ os.makedirs(cache_dir)
+
+ # Get eTag to add to filename, if it exists.
+ if url.startswith("s3://"):
+ etag = s3_etag(url)
+ else:
+ response = requests.head(url, allow_redirects=True)
+ if response.status_code != 200:
+ raise IOError("HEAD request failed for url {} with status code {}"
+ .format(url, response.status_code))
+ etag = response.headers.get("ETag")
+
+ filename = url_to_filename(url, etag)
+
+ # get cache path to put the file
+ cache_path = os.path.join(cache_dir, filename)
+
+ if not os.path.exists(cache_path):
+ # Download to temporary file, then copy to cache dir once finished.
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
+ with tempfile.NamedTemporaryFile() as temp_file:
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+
+ # GET file object
+ if url.startswith("s3://"):
+ s3_get(url, temp_file)
+ else:
+ http_get(url, temp_file)
+
+ # we are copying the file before closing it, so flush to avoid truncation
+ temp_file.flush()
+ # shutil.copyfileobj() starts at the current position, so go to the start
+ temp_file.seek(0)
+
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
+ with open(cache_path, 'wb') as cache_file:
+ shutil.copyfileobj(temp_file, cache_file)
+
+ logger.info("creating metadata file for %s", cache_path)
+ meta = {'url': url, 'etag': etag}
+ meta_path = cache_path + '.json'
+ with open(meta_path, 'w', encoding="utf-8") as meta_file:
+ json.dump(meta, meta_file)
+
+ logger.info("removing temp file %s", temp_file.name)
+
+ return cache_path
+
+
+def read_set_from_file(filename):
+ '''
+ Extract a de-duped collection (set) of text from a file.
+ Expected file format is one item per line.
+ '''
+ collection = set()
+ with open(filename, 'r', encoding='utf-8') as file_:
+ for line in file_:
+ collection.add(line.rstrip())
+ return collection
+
+
+def get_file_extension(path, dot=True, lower=True):
+ ext = os.path.splitext(path)[1]
+ ext = ext if dot else ext[1:]
+ return ext.lower() if lower else ext
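+
+
+# Illustrative usage sketch: cached_path() accepts either a local path (returned as-is
+# after an existence check) or an http(s)/s3 URL, which is downloaded into
+# PYTORCH_PRETRAINED_BERT_CACHE under a name derived from url_to_filename(url, etag).
+# The commented-out URL below is a placeholder, not a real hosted file.
+def _example_cached_path_usage():
+    local = cached_path(__file__)   # existing local file: the path is returned unchanged
+    # remote = cached_path('https://example.com/vocab.txt')  # would download and cache
+    return local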
diff --git a/megatron_lm/megatron/deprecated_data_utils/lazy_loader.py b/megatron_lm/megatron/deprecated_data_utils/lazy_loader.py
new file mode 100644
index 0000000..506f529
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/lazy_loader.py
@@ -0,0 +1,202 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""utils for loading text from disk"""
+import os
+import mmap
+import pickle as pkl
+import time
+from itertools import accumulate
+
+import torch
+from torch.multiprocessing import Lock
+
+
+def get_lazy_path(path):
+ """
+ Gets directory path where lazy files are stored.
+ """
+ return os.path.splitext(path)[0] + '.lazy'
+
+
+def exists_lazy(path, data_type='data'):
+ """
+ Check if we've already made a lazy version of this file for the `data_type` field.
+ """
+ if not os.path.exists(get_lazy_path(path)):
+ return False
+ contents = os.listdir(get_lazy_path(path))
+ if data_type not in contents:
+ return False
+ if data_type + '.len.pkl' not in contents:
+ return False
+ return True
+
+
+def make_lazy(path, strs, data_type='data'):
+ """
+    Make a lazy version of the `data_type` field of the file. The byte length of
+    each entry is stored in a `.len.pkl` file so offsets can be reconstructed.
+ """
+ lazypath = get_lazy_path(path)
+ if not os.path.exists(lazypath):
+ os.makedirs(lazypath)
+ datapath = os.path.join(lazypath, data_type)
+ lenpath = os.path.join(lazypath, data_type + '.len.pkl')
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+ with open(datapath, 'wb') as f:
+ str_lens = []
+ str_cnt = 0
+ for s in strs:
+ if isinstance(s, dict):
+ s = s['text']
+ encoded = s.encode('utf-8')
+ f.write(encoded)
+ str_cnt = len(encoded)
+ str_lens.append(str_cnt)
+ pkl.dump(str_lens, open(lenpath, 'wb'))
+ else:
+ while not os.path.exists(lenpath):
+ time.sleep(1)
+
+
+def split_strings(strings, start, chr_lens):
+ """
+ Split strings based on string lengths and given start.
+ """
+ return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]
+
+
+class ProcessorTokenizer:
+ """
+    Callable class that runs a preprocessing step, as well as tokenization,
+    on input text.
+ """
+
+ def __init__(self, tokenizer, process_fn=None):
+ self.tokenizer = tokenizer
+ self.process_fn = process_fn
+
+ def __call__(self, string):
+ if self.tokenizer is not None:
+ string = self.tokenizer(string, process_fn=self.process_fn)
+ elif self.process_fn is not None:
+ string = self.process_fn(string)
+ return string
+
+
+class lazy_array_loader(object):
+ """
+ Arguments:
+        path: path to the original data file whose `.lazy/` directory holds the
+            concatenated string file and the corresponding `.len.pkl` file
+        data_type (str): Some datasets have multiple fields that are stored in different paths.
+ `data_type` specifies which of these fields to load in this class
+ mem_map (boolean): Specifies whether to memory map file `path`
+ map_fn (callable): Fetched strings are passed through map_fn before being returned.
+
+ Example of lazy loader directory structure:
+ file.json
+ file.lazy/
+ data_type1
+ data_type1.len.pkl
+ data_type2
+ data_type2.len.pkl
+ """
+
+ def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
+ lazypath = get_lazy_path(path)
+ datapath = os.path.join(lazypath, data_type)
+ # get file where array entries are concatenated into one big string
+ self._file = open(datapath, 'rb', buffering=0)
+ self.file = self._file
+ # memory map file if necessary
+ self.mem_map = mem_map
+ if self.mem_map:
+ self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
+ lenpath = os.path.join(lazypath, data_type + '.len.pkl')
+ self.lens = pkl.load(open(lenpath, 'rb'))
+ self.ends = list(accumulate(self.lens))
+ self.dumb_ends = list(self.ends)
+ self.read_lock = Lock()
+ self.process_fn = map_fn
+ self.map_fn = map_fn
+ self._tokenizer = None
+
+ def SetTokenizer(self, tokenizer):
+ """
+        Logic to set or remove (set to None) the tokenizer.
+        Combines preprocessing and tokenization into one callable.
+ """
+ if tokenizer is None:
+ if not hasattr(self, '_tokenizer'):
+ self._tokenizer = tokenizer
+ else:
+ self._tokenizer = tokenizer
+ self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn)
+
+ def GetTokenizer(self):
+ return self._tokenizer
+
+ def __getitem__(self, index):
+ """
+ read file and splice strings based on string ending array `self.ends`
+ """
+ if not isinstance(index, slice):
+ if index == 0:
+ start = 0
+ else:
+ start = self.ends[index - 1]
+ end = self.ends[index]
+ rtn = self.file_read(start, end)
+ if self.map_fn is not None:
+ return self.map_fn(rtn)
+ else:
+ # if slice, fetch strings with 1 diskread and then splice in memory
+ chr_lens = self.ends[index]
+ if index.start == 0 or index.start is None:
+ start = 0
+ else:
+ start = self.ends[index.start - 1]
+ stop = chr_lens[-1]
+ strings = self.file_read(start, stop)
+ rtn = split_strings(strings, start, chr_lens)
+ if self.map_fn is not None:
+ return self.map_fn([s for s in rtn])
+        return rtn
+
+ def __len__(self):
+ return len(self.ends)
+
+ def file_read(self, start=0, end=None):
+ """read specified portion of file"""
+
+ # atomic reads to avoid race conditions with multiprocess dataloader
+ self.read_lock.acquire()
+ # seek to start of file read
+ self.file.seek(start)
+ # read to end of file if no end point provided
+ if end is None:
+ rtn = self.file.read()
+ # else read amount needed to reach end point
+ else:
+ rtn = self.file.read(end - start)
+ self.read_lock.release()
+ # TODO: @raulp figure out mem map byte string bug
+ # if mem map'd need to decode byte string to string
+ rtn = rtn.decode('utf-8', 'ignore')
+ # rtn = str(rtn)
+ if self.mem_map:
+ rtn = rtn.decode('unicode_escape')
+ return rtn
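+
+
+# Illustrative usage sketch: make_lazy() writes the concatenated UTF-8 text of a corpus
+# plus per-entry byte lengths into a sibling '.lazy/' directory, and lazy_array_loader
+# reads individual entries back on demand. 'corpus.json' is a placeholder path; any
+# writable location works the same way.
+def _example_lazy_loader_usage():
+    docs = ['first document', 'second document', 'third document']
+    if not exists_lazy('corpus.json', data_type='data'):
+        make_lazy('corpus.json', docs, data_type='data')
+    loader = lazy_array_loader('corpus.json', data_type='data', mem_map=False,
+                               map_fn=lambda s: s.strip())
+    return loader[1]    # -> 'second document'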
diff --git a/megatron_lm/megatron/deprecated_data_utils/samplers.py b/megatron_lm/megatron/deprecated_data_utils/samplers.py
new file mode 100644
index 0000000..baa6b9d
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/samplers.py
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""batch samplers that work with either random or sequential data samplers"""
+import math
+import os
+import sys
+
+import torch
+from torch.utils import data
+import numpy as np
+
+
+class RandomSampler(data.sampler.Sampler):
+ r"""
+    Based on PyTorch's RandomSampler and DistributedSampler. Essentially a RandomSampler,
+    but this class lets the user set an epoch like DistributedSampler.
+ Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+ If with replacement, then user can specify ``num_samples`` to draw.
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ num_samples (int): number of samples to draw, default=len(dataset)
+ replacement (bool): samples are drawn with replacement if ``True``, default=False
+ """
+
+ def __init__(self, data_source, replacement=False, num_samples=None):
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+ self.epoch = -1
+
+ if self._num_samples is not None and replacement is False:
+ raise ValueError("With replacement=False, num_samples should not be specified, "
+ "since a random permute will be performed.")
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError("num_samples should be a positive integer "
+ "value, but got num_samples={}".format(self.num_samples))
+ if not isinstance(self.replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(self.replacement))
+
+ @property
+ def num_samples(self):
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self):
+ n = len(self.data_source)
+ g = torch.Generator()
+ if self.epoch >= 0:
+ g.manual_seed(self.epoch)
+ if self.replacement:
+ return iter(torch.randint(high=n, size=(self.num_samples,),
+ dtype=torch.int64, generator=g).tolist())
+ return iter(torch.randperm(n, generator=g).tolist())
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class DistributedBatchSampler(data.sampler.BatchSampler):
+ """
+    Similar to the normal implementation of a distributed sampler, except the implementation
+    is at the batch-sampler level instead of the sampler level. This allows wrapping arbitrary
+    data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
+ """
+
+ def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
+ super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
+ if rank == -1:
+ assert False, 'should not be here'
+ rank = torch.distributed.get_rank()
+ self.rank = rank
+ self.world_size = world_size
+ self.sampler.wrap_around = 0
+ self.wrap_around = 0
+ self.wrap_last = wrap_last
+ self.start_iter = 0
+
+ def __iter__(self):
+ batch = []
+ last_batch = None
+ i = 0
+ for idx in self.data_iterator(self.sampler, wrap_around=False):
+ batch.append(idx)
+ if len(batch) == self.batch_size:
+ tbatch = self._batch(batch)
+ if i >= self.start_iter:
+ yield tbatch
+ self.start_iter = 0
+ i += 1
+ last_batch = np.array(list(tbatch))
+ batch = []
+ batch_len = len(batch)
+ if batch_len > 0 and not self.drop_last:
+ if self.wrap_last:
+ self.sampler.wrap_around -= (self.batch_size)
+ self.wrap_around += (len(batch))
+ self.wrap_around %= self.batch_size
+ if isinstance(self.sampler, TransposedSampler):
+ for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
+ if i == 0:
+ continue
+ batch.append(idx)
+ new_batch_len = len(batch)
+ if len(batch) == self.batch_size:
+ break
+ yield self._batch(batch)
+ if self.wrap_last:
+ self.sampler.wrap_around += self.batch_size
+
+ def data_iterator(self, _iter, wrap_around=False):
+ """iterates through data and handles wrap around"""
+ for i, idx in enumerate(_iter):
+ if i < self.wrap_around % self.batch_size:
+ continue
+ if wrap_around:
+ self.wrap_around += 1
+ self.wrap_around %= self.batch_size
+ yield idx
+
+ def _batch(self, batch):
+ """extracts samples only pertaining to this worker's batch"""
+ start = self.rank * self.batch_size // self.world_size
+ end = (self.rank + 1) * self.batch_size // self.world_size
+ return batch[start:end]
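+
+
+# Illustrative sketch: DistributedBatchSampler hands each rank its slice of every
+# global batch, so with batch_size=4 and world_size=2, rank 0 sees elements [0:2] of
+# each batch and rank 1 sees [2:4]. The rank/world_size values below are toy
+# stand-ins; in training they come from torch.distributed.
+def _example_distributed_batch_sampler():
+    source = list(range(8))                    # toy dataset indices
+    sampler = data.sampler.SequentialSampler(source)
+    batch_sampler = DistributedBatchSampler(sampler, batch_size=4, drop_last=True,
+                                            rank=0, world_size=2)
+    return [batch for batch in batch_sampler]  # [[0, 1], [4, 5]]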
diff --git a/megatron_lm/megatron/deprecated_data_utils/scripts/presplit_sentences_json.py b/megatron_lm/megatron/deprecated_data_utils/scripts/presplit_sentences_json.py
new file mode 100644
index 0000000..f150f2f
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/scripts/presplit_sentences_json.py
@@ -0,0 +1,27 @@
+"""
+Usage:
+python scripts/presplit_sentences_json.py <original loose json file> <output loose json file>
+"""
+
+import sys
+import json
+
+import nltk
+
+nltk.download('punkt')
+
+input_file = sys.argv[1]
+output_file = sys.argv[2]
+
+line_separator = "\n"
+
+with open(input_file, 'r') as ifile:
+ with open(output_file, "w") as ofile:
+ for doc in ifile.readlines():
+ parsed = json.loads(doc)
+ sent_list = []
+ for line in parsed['text'].split('\n'):
+                if line:
+                    sent_list.extend(nltk.tokenize.sent_tokenize(line))
+            parsed['text'] = line_separator.join(sent_list)
+ ofile.write(json.dumps(parsed) + '\n')
diff --git a/megatron_lm/megatron/deprecated_data_utils/scripts/split_gpt2_json.py b/megatron_lm/megatron/deprecated_data_utils/scripts/split_gpt2_json.py
new file mode 100644
index 0000000..e6ddb1b
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/scripts/split_gpt2_json.py
@@ -0,0 +1,141 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Takes a corpus of files (specified by `--input_files`) with json data separated
+by newlines (loose json). Splits the data into train.json, dev.json, test.json files
+under `output_dir`.
+
+Note: This code has the potential to overwrite files with the names
+train.json, dev.json, test.json in `--output_dir`.
+"""
+import os
+import argparse
+import math
+import random
+
+parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
+parser.add_argument('--input_files', nargs='+', required=True,
+ help='whitespace separated list of input data files')
+parser.add_argument('--output_dir', required=True,
+ help='output directory where to put files')
+parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
+ help='percentage of available data to use for val/test dataset')
+args = parser.parse_args()
+
+
+def get_lines(filepath):
+ lines = []
+ with open(filepath, 'r') as f:
+ for i, l in enumerate(f.readlines()):
+ l = l.strip()
+ lines.append(l)
+ return lines
+
+
+def get_splits(lines, line_counts):
+ all_lines = []
+ line_idx = []
+ file_mappings = []
+ for i, l in enumerate(lines):
+ all_lines.extend(l)
+ line_idx.extend(list(range(len(l))))
+ file_mappings.extend([i] * len(l))
+
+ indices = list(range(len(all_lines)))
+ random.shuffle(indices)
+ all_lines = [all_lines[idx] for idx in indices]
+ line_idx = [line_idx[idx] for idx in indices]
+ file_mappings = [file_mappings[idx] for idx in indices]
+
+ splits = []
+ mappings = []
+ start = 0
+ for end in line_counts:
+ end += start
+ splits.append(all_lines[start:end])
+ mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
+ start = end
+ return splits, mappings
+
+
+def format_mappings(line_idx, file_mappings):
+ lines = []
+ for m, l in zip(file_mappings, line_idx):
+ lines.append(str(m).strip() + '\t' + str(l).strip())
+ return lines
+
+
+def get_filepaths(filepaths, output_dir):
+ paths = []
+ train_path = 'train.json'
+ dev_path = 'dev.json'
+ test_path = 'test.json'
+ paths.append(os.path.join(output_dir, train_path))
+ paths.append(os.path.join(output_dir, dev_path))
+ paths.append(os.path.join(output_dir, test_path))
+ return paths
+
+
+def write_files(lines, mappings, filepaths):
+ for l, m, path in zip(lines, mappings, filepaths):
+ write_file(l, path)
+ write_mapping_file(m, path)
+
+
+def write_file(lines, path):
+ print('Writing:', path)
+ with open(path, 'w') as f:
+ for l in lines:
+ f.write(l + '\n')
+
+
+def write_mapping_file(m, path):
+ path = path + '.map'
+ m = [get_mapping_header()] + m
+ write_file(m, path)
+
+
+def get_mapping_header():
+ return 'file\tline #'
+
+
+if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+lines = []
+
+for filepath in args.input_files:
+ _lines = get_lines(filepath)
+ lines.append(_lines)
+
+# calculate number of lines to use for each
+line_counts = [len(l) for l in lines]
+total_lines = sum(line_counts)
+dev_percent = args.test_percent[0]
+dev_lines = math.ceil(dev_percent * total_lines)
+test_percent = 0
+if len(args.test_percent) == 2:
+ test_percent = args.test_percent[1]
+test_lines = math.ceil(test_percent * total_lines)
+train_lines = total_lines - (test_lines + dev_lines)
+normed_lines = [train_lines, dev_lines, test_lines]
+normed_lines = [int(l) for l in normed_lines]
+
+
+splits, mappings = get_splits(lines, normed_lines)
+filepaths = get_filepaths(args.input_files, args.output_dir)
+print('Writing output to:', filepaths)
+write_files(splits, mappings, filepaths)
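+
+
+# Illustrative sketch of the split arithmetic above: with 1000 total lines and
+# --test_percent 0.05 0.01, the dev split gets ceil(0.05 * 1000) = 50 lines, the test
+# split gets ceil(0.01 * 1000) = 10 lines, and train keeps the remaining 940.
+def _example_split_counts(total_lines=1000, test_percent=(0.05, 0.01)):
+    dev_count = math.ceil(test_percent[0] * total_lines)
+    test_count = math.ceil(test_percent[1] * total_lines) if len(test_percent) == 2 else 0
+    train_count = total_lines - (dev_count + test_count)
+    return [train_count, dev_count, test_count]   # [940, 50, 10]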
diff --git a/megatron_lm/megatron/deprecated_data_utils/scripts/split_json.py b/megatron_lm/megatron/deprecated_data_utils/scripts/split_json.py
new file mode 100644
index 0000000..7d2958c
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/scripts/split_json.py
@@ -0,0 +1,126 @@
+"""
+Takes a corpus of files (specified by `--input_files`) with json data separated
+by newlines (loose json). Splits the data into train.json, dev.json, test.json files
+under `output_dir`.
+
+Note: This code has the potential to overwrite files with the names
+train.json, dev.json, test.json in `--output_dir`.
+"""
+import os
+import argparse
+import math
+import random
+
+parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
+parser.add_argument('--input_files', nargs='+', required=True,
+ help='whitespace separated list of input data files')
+parser.add_argument('--output_dir', required=True,
+ help='output directory where to put files')
+parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
+ help='percentage of available data to use for val/test dataset')
+args = parser.parse_args()
+
+
+def get_lines(filepath):
+ lines = []
+ with open(filepath, 'r') as f:
+ for i, l in enumerate(f.readlines()):
+ l = l.strip()
+ lines.append(l)
+ return lines
+
+
+def get_splits(lines, line_counts):
+ all_lines = []
+ line_idx = []
+ file_mappings = []
+ for i, l in enumerate(lines):
+ all_lines.extend(l)
+ line_idx.extend(list(range(len(l))))
+ file_mappings.extend([i] * len(l))
+
+ indices = list(range(len(all_lines)))
+ random.shuffle(indices)
+ all_lines = [all_lines[idx] for idx in indices]
+ line_idx = [line_idx[idx] for idx in indices]
+ file_mappings = [file_mappings[idx] for idx in indices]
+
+ splits = []
+ mappings = []
+ start = 0
+ for end in line_counts:
+ end += start
+ splits.append(all_lines[start:end])
+ mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
+ start = end
+ return splits, mappings
+
+
+def format_mappings(line_idx, file_mappings):
+ lines = []
+ for m, l in zip(file_mappings, line_idx):
+ lines.append(str(m).strip() + '\t' + str(l).strip())
+ return lines
+
+
+def get_filepaths(filepaths, output_dir):
+ paths = []
+ train_path = 'train.json'
+ dev_path = 'dev.json'
+ test_path = 'test.json'
+ paths.append(os.path.join(output_dir, train_path))
+ paths.append(os.path.join(output_dir, dev_path))
+ paths.append(os.path.join(output_dir, test_path))
+ return paths
+
+
+def write_files(lines, mappings, filepaths):
+ for l, m, path in zip(lines, mappings, filepaths):
+ write_file(l, path)
+ write_mapping_file(m, path)
+
+
+def write_file(lines, path):
+ print('Writing:', path)
+ with open(path, 'w') as f:
+ for l in lines:
+ f.write(l + '\n')
+
+
+def write_mapping_file(m, path):
+ path = path + '.map'
+ m = [get_mapping_header()] + m
+ write_file(m, path)
+
+
+def get_mapping_header():
+ return 'file\tline #'
+
+
+if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+lines = []
+
+for filepath in args.input_files:
+ _lines = get_lines(filepath)
+ lines.append(_lines)
+
+# calculate number of lines to use for each
+line_counts = [len(l) for l in lines]
+total_lines = sum(line_counts)
+dev_percent = args.test_percent[0]
+dev_lines = math.ceil(dev_percent * total_lines)
+test_percent = 0
+if len(args.test_percent) == 2:
+ test_percent = args.test_percent[1]
+test_lines = math.ceil(test_percent * total_lines)
+train_lines = total_lines - (test_lines + dev_lines)
+normed_lines = [train_lines, dev_lines, test_lines]
+normed_lines = [int(l) for l in normed_lines]
+
+
+splits, mappings = get_splits(lines, normed_lines)
+filepaths = get_filepaths(args.input_files, args.output_dir)
+print('Writing output to:', filepaths)
+write_files(splits, mappings, filepaths)
diff --git a/megatron_lm/megatron/deprecated_data_utils/tf_dl.py b/megatron_lm/megatron/deprecated_data_utils/tf_dl.py
new file mode 100755
index 0000000..7d93ab0
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/tf_dl.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DataLoader for TFRecords"""
+
+import numpy as np
+import torch
+import queue
+import threading
+
+import tensorflow as tf
+tf.enable_eager_execution()
+
+
+class TFRecordDataLoader(object):
+ def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq,
+ train, num_workers=2, seed=1, threaded_dl=False):
+ assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
+ tf.set_random_seed(seed)
+ if isinstance(records, str):
+ records = [records]
+
+ self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
+ "input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
+ "segment_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
+ "masked_lm_positions": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
+ "masked_lm_ids": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
+ "masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
+ "next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})
+
+ # Instantiate dataset according to original BERT implementation
+ if train:
+ self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
+ self.dataset = self.dataset.repeat()
+ self.dataset = self.dataset.shuffle(buffer_size=len(records))
+
+ # use sloppy tfrecord dataset
+ self.dataset = self.dataset.apply(
+ tf.contrib.data.parallel_interleave(
+ tf.data.TFRecordDataset,
+ sloppy=train,
+ cycle_length=min(num_workers, len(records))))
+ self.dataset = self.dataset.shuffle(buffer_size=100)
+ else:
+ self.dataset = tf.data.TFRecordDataset(records)
+ self.dataset = self.dataset.repeat()
+
+ # Instantiate dataloader (do not drop remainder for eval)
+ loader_args = {'batch_size': batch_size,
+ 'num_parallel_batches': num_workers,
+ 'drop_remainder': train}
+ self.dataloader = self.dataset.apply(
+ tf.contrib.data.map_and_batch(
+ self.record_converter, **loader_args))
+ self.threaded_dl = threaded_dl
+ self.num_workers = num_workers
+
+ def __iter__(self):
+ if self.threaded_dl:
+ data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
+ for item in data_iter:
+ yield item
+ else:
+ data_iter = iter(self.dataloader)
+ for item in data_iter:
+ yield convert_tf_example_to_torch_tensors(item)
+
+
+class Record2Example(object):
+ def __init__(self, feature_map):
+ self.feature_map = feature_map
+
+ def __call__(self, record):
+ """Decodes a BERT TF record to a TF example."""
+ example = tf.parse_single_example(record, self.feature_map)
+ for k, v in list(example.items()):
+ if v.dtype == tf.int64:
+ example[k] = tf.to_int32(v)
+ return example
+
+
+def convert_tf_example_to_torch_tensors(example):
+ item = {k: (v.numpy()) for k, v in example.items()}
+ mask = np.zeros_like(item['input_ids'])
+ mask_labels = np.ones_like(item['input_ids']) * -1
+ for b, row in enumerate(item['masked_lm_positions'].astype(int)):
+ for i, idx in enumerate(row):
+ if item['masked_lm_weights'][b, i] != 0:
+ mask[b, idx] = 1
+ mask_labels[b, idx] = item['masked_lm_ids'][b, i]
+ output = {'text': item['input_ids'], 'types': item['segment_ids'], 'is_random': item['next_sentence_labels'],
+ 'pad_mask': 1 - item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
+ return {k: torch.from_numpy(v) for k, v in output.items()}
+
+
+class MultiprocessLoader(object):
+ def __init__(self, dataloader, num_workers=2):
+ self.dl = dataloader
+ self.queue_size = 2 * num_workers
+
+ def __iter__(self):
+ output_queue = queue.Queue(self.queue_size)
+ output_thread = threading.Thread(target=_multiproc_iter,
+ args=(self.dl, output_queue))
+ output_thread.daemon = True
+ output_thread.start()
+
+ while output_thread.is_alive():
+ yield output_queue.get(block=True)
+ else:
+ print(RuntimeError('TF record data loader thread exited unexpectedly'))
+
+
+def _multiproc_iter(dl, output_queue):
+ data_iter = iter(dl)
+ for item in data_iter:
+ tensors = convert_tf_example_to_torch_tensors(item)
+ output_queue.put(tensors, block=True)
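+
+
+# Illustrative sketch of convert_tf_example_to_torch_tensors: it only needs per-key
+# objects exposing .numpy(), so torch tensors work as stand-ins for eager TF tensors
+# here. The toy values are arbitrary; position 1 is masked with original id 42.
+def _example_convert_tf_example():
+    example = {'input_ids': torch.tensor([[5, 6, 7, 8]]),
+               'input_mask': torch.tensor([[1, 1, 1, 0]]),
+               'segment_ids': torch.tensor([[0, 0, 1, 1]]),
+               'masked_lm_positions': torch.tensor([[1, 0]]),
+               'masked_lm_ids': torch.tensor([[42, 0]]),
+               'masked_lm_weights': torch.tensor([[1.0, 0.0]]),
+               'next_sentence_labels': torch.tensor([1])}
+    return convert_tf_example_to_torch_tensors(example)   # dict of torch tensors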
diff --git a/megatron_lm/megatron/deprecated_data_utils/tokenization.py b/megatron_lm/megatron/deprecated_data_utils/tokenization.py
new file mode 100755
index 0000000..c37e6f1
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/tokenization.py
@@ -0,0 +1,922 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)"""
+from collections import namedtuple
+import random
+import os
+import csv
+import torch
+
+import nltk
+from nltk import tokenize as nltk_tokenize
+import sentencepiece as spm
+
+from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+
+from .tokenization_gpt2 import GPT2Tokenizer
+import regex as re
+
+
+def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe',
+ pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
+ """
+ Helper function to instantiate a tokenizer given common combinations of options.
+ """
+ tokenizer_class = tokenizer_type
+ if isinstance(tokenizer_class, str):
+ tokenizer_class = eval(tokenizer_class)
+ if tokenizer_class is BertWordPieceTokenizer:
+ return BertWordPieceTokenizer(model_type, **kwargs)
+ elif tokenizer_class is GPT2BPETokenizer:
+ return GPT2BPETokenizer(**kwargs)
+ text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type,
+ pad_token=pad_token, character_coverage=character_coverage)
+ return Tokenizer(text_tokenizer, command_tokens, type_tokens)
+
+
+class Tokenization(object):
+ """
+    Tokenization object to hold the tokenization, (processed) text, and original
+    text. Can hold the tokenization as Ids or tokens.
+
+ It also holds command tokens (pad, unk, etc.) for the tokenization.
+ This allows functions to pad/operate on tokenizations without having
+ access to the full tokenizer, just the tokenization.
+
+ Several standard array operations are implemented (insert, append, extend).
+ """
+
+ def __init__(self, tokenization, text=None, original_text=None,
+ command_tokens=None, asIds=True):
+ self.tokenization = tokenization
+ self.text = text
+ if self.text is None:
+ self.text = self.tokenization
+ self.original_text = original_text
+ if self.original_text is None:
+ self.original_text = self.text
+ self.command_tokens = command_tokens
+ self.asIds = asIds
+ self.parse_command_tokens()
+
+ def set_command_tokens(self, command_tokens):
+ self.command_tokens = command_tokens
+ return self.parse_command_tokens()
+
+ def parse_command_tokens(self):
+ if self.command_tokens is None:
+ return
+ for command_token in self.command_tokens:
+ if self.asIds:
+ setattr(self, command_token.name, command_token.Id)
+ else:
+ setattr(self, command_token.name, command_token.token)
+
+ def __getitem__(self, index):
+ return self.tokenization[index]
+
+ def __len__(self):
+ return len(self.tokenization)
+
+ def insert(self, idx, other):
+ if isinstance(other, (CommandToken, TypeToken)):
+ self.tokenization.insert(idx, other.Id)
+ if idx == 0:
+ self.text = other.token + self.text
+ self.original_text = other.token + self.original_text
+ elif idx == len(self.tokenization) - 1:
+ self.text += other.token
+ self.original_text += other.token
+ elif isinstance(other, Tokenization):
+ self.tokenization = self.tokenization[:idx] + \
+ other.tokenization + self.tokenization[idx:]
+ else:
+ self.tokenization = self.tokenization[:idx] + \
+ other.tokenization + self.tokenization[idx:]
+
+ def append(self, other):
+ if isinstance(other, (CommandToken, TypeToken)):
+ self.tokenization.append(other.Id)
+ self.text += other.token
+ self.original_text += other.token
+ elif isinstance(other, Tokenization):
+ self.tokenization.extend(other.tokenization)
+ self.text += other.text
+ self.original_text += other.original_text
+ else:
+ self.tokenization.append(other)
+ return self
+
+ def extend(self, other):
+ if isinstance(other, (CommandToken, TypeToken)):
+ self.tokenization.append(other.Id)
+ self.text += other.token
+ self.original_text += other.token
+ elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)):
+ self.tokenization.extend([o.Id for o in other])
+ self.text += [o.token for o in other]
+ self.original_text += [o.token for o in other]
+ elif isinstance(other, Tokenization):
+ self.tokenization.extend(other.tokenization)
+ self.text += other.text
+ self.original_text += other.original_text
+ else:
+ self.tokenization.extend(other)
+ return self
+
+
+"""define some default command tokens for the tokenizer to use"""
+token_format = "<{0}>"
+
+COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
+
+
+def prep_command_tokens(tokenlist, token_format=token_format):
+ return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
+
+
+class CommandToken(object):
+ def __init__(self, name, token, Id):
+ self.name = name
+ self.token = token
+ self.Id = Id
+
+ def __str__(self):
+ return str(COMMAND_TUPLE(self.name, self.token, self.Id))
+
+
+DEFAULT_COMMAND_TOKENS = [
+ ('pad', 0),
+ ('eos', 1),
+ ('bos', 2),
+ ('unk', 3),
+ ('sep', 4),
+ ('L2R', 5),
+ ('ENC', 6),
+ ('MASK', 7),
+]
+DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
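+
+
+# Illustrative sketch: prep_command_tokens wraps (name, Id) pairs in CommandToken
+# objects whose surface form follows token_format, so ('pad', 0) becomes the token
+# '<pad>' with Id 0; Tokenizer later looks these up by name via get_command('pad').
+def _example_command_token():
+    tok = prep_command_tokens([('pad', 0)])[0]
+    return tok.name, tok.token, tok.Id   # ('pad', '<pad>', 0)
+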
+
+"""define some default type tokens for bert training"""
+
+TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
+
+
+def prep_type_tokens(tokenlist, token_format=token_format):
+ return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
+
+
+class TypeToken(object):
+ def __init__(self, name, token, Id):
+ self.name = name
+ self.token = token
+ self.Id = Id
+
+ def __str__(self):
+ return str(TYPE_TUPLE(self.name, self.token, self.Id))
+
+
+DEFAULT_TYPE_TOKENS = [
+ ('function', 0),
+ ('command', 1),
+ ('str0', 2),
+ ('str1', 3),
+ ('str2', 4),
+ ('embedding0', 5),
+ ('embedding1', 6),
+ ('embedding2', 7),
+ ('arg0', 8),
+ ('arg1', 9),
+ ('arg2', 10),
+]
+DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
+
+
+class Tokenizer(object):
+ """
+ Tokenizer object that handles text tokenization, command tokens, and type tokens.
+
+    Command tokens and text tokens are stored together in one mapping of size
+    `len(text_tokenizer)+len(command_tokens)`. Command tokens occupy the first
+    `len(command_tokens)` Ids; a text token with index `idx` in the underlying text
+    tokenizer is stored at `idx+len(command_tokens)`.
+
+ Token types are stored in a separate mapping of size `len(type_tokens)`.
+ """
+
+ def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
+ # set text tokenizer
+ self.text_tokenizer = text_tokenizer
+ if not hasattr(self, 'num_text_tokens'):
+ self.num_text_tokens = len(self.text_tokenizer)
+
+ # set command tokens
+ if command_tokens is None:
+ command_tokens = DEFAULT_COMMAND_TOKENS
+ self._command_tokens = command_tokens
+ self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+ self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+ self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+ if not hasattr(self, 'num_command_tokens'):
+ self.num_command_tokens = len(self._command_tokens)
+ if not hasattr(self, 'num_tokens'):
+ self.num_tokens = self.num_command_tokens + self.num_text_tokens
+
+ # set type tokens
+ if type_tokens is None:
+ type_tokens = DEFAULT_TYPE_TOKENS
+ self.type_tokens = type_tokens
+ self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+ self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+ self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+ if not hasattr(self, 'num_type_tokens'):
+ self.num_type_tokens = len(self.type_tokens)
+
+ # parse tokens and vocabs from tokenizer
+ self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
+ self._vocab = {t: Id for Id, t in self.command_id_map.items()}
+ self._vocab.update({t: Id + self.num_command_tokens for t,
+ Id in self.text_tokenizer.vocab.items()})
+
+ self._text_tokens = list(self.text_tokenizer.tokens)
+ self._text_token_vocab = {
+ t: Id + self.num_command_tokens for t,
+ Id in self.text_tokenizer.vocab.items()}
+
+ self._command_token_tokens = list(self.command_token_map.keys())
+ self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+ self._token_types = list(self.type_token_map.keys())
+ self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+ def __call__(self, text, process_fn=None):
+ """run preprocessing and encode text as Ids"""
+ return self.EncodeAsIds(text, process_fn=process_fn)
+
+ def __len__(self):
+ """total number of tokens"""
+ return self.num_tokens
+
+ def get_command(self, name):
+ """get command token corresponding to `name`"""
+ return self.command_name_map[name]
+
+ def get_type(self, name):
+ """get type token corresponding to `name`"""
+ return self.type_name_map[name]
+
+ @property
+ def tokens(self):
+ """list (or iterable) of all tokens for tokenizer"""
+ return self._tokens
+
+ @property
+ def vocab(self):
+ """dictionary mapping tokens to ids for tokenizer"""
+ return self._vocab
+
+ @property
+ def token_types(self):
+ """list (or iterable) of all token types for tokenizer"""
+ return self._token_types
+
+ @property
+ def token_type_vocab(self):
+ """dictionary mapping token types to ids for tokenizer"""
+ return self._token_type_vocab
+
+ @property
+ def command_tokens(self):
+ """list (or iterable) of all command tokens for tokenizer"""
+ return self._command_token_tokens
+
+ @property
+ def command_token_vocab(self):
+ """dictionary mapping command tokens to ids for tokenizer"""
+ return self._command_token_vocab
+
+ @property
+ def text_tokens(self):
+ """list (or iterable) of text tokens for text tokenizer"""
+ return self._text_tokens
+
+ @property
+ def text_token_vocab(self):
+ """dictionary mapping text tokens to ids for text tokenizer"""
+ return self._text_token_vocab
+
+ def EncodeAsIds(self, text, process_fn=None):
+ """
+ encode text using text tokenizer and shift Id values for command tokens
+ """
+ tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn)
+ tokenization.tokenization = [t + self.num_command_tokens for t in tokenization.tokenization]
+ tokenization.set_command_tokens(self._command_tokens)
+ return tokenization
+
+ def EncodeAsTokens(self, text, process_fn=None):
+ """
+ encode text as tokens using text tokenizer
+ """
+ tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn)
+ tokenization.set_command_tokens(self._command_tokens)
+ return tokenization
+
+ def IdToToken(self, Id, type_token=False):
+ """convert Id to token accounting for command and type tokens"""
+ if isinstance(Id, (TypeToken, CommandToken)):
+ return Id.token
+ if type_token:
+ return self.type_id_map[Id].token
+ if Id < self.num_command_tokens:
+ return self.command_id_map[Id].token
+ return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)
+
+ def TokenToId(self, token, type_token=False):
+ """convert token to Id accounting for command and type tokens"""
+ if isinstance(token, (TypeToken, CommandToken)):
+ return token.Id
+ if type_token:
+ return self.type_token_map[token].Id
+ if token in self.command_token_map:
+ return self.command_token_map[token].Id
+ return self.text_tokenizer.TokenToId(token) + self.num_command_tokens
+
+ def DecodeIds(self, Ids, type_token=False):
+ """
+ convert Ids to tokens, accounting for command and type tokens; the tokens
+ are joined and returned as a string.
+ """
+ if type_token:
+ return ' '.join(Id.token if isinstance(Id, TypeToken)
+ else self.type_id_map[Id].token for Id in Ids)
+ rtn_strs = []
+ current_str = []
+ if isinstance(Ids, Tokenization):
+ Ids = Ids.tokenization
+ for Id in Ids:
+ if isinstance(Id, CommandToken):
+ rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+ current_str = []
+ rtn_strs.append(Id.token)
+ elif Id < self.num_command_tokens:
+ rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+ current_str = []
+ rtn_strs.append(self.command_id_map[Id].token)
+ else:
+ current_str.append(Id - self.num_command_tokens)
+ if current_str != []:
+ rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+ return ' '.join(rtn_strs)
+
+ def DecodeTokens(self, Tokens, type_token=False):
+ """
+ convert tokens to a string accounting for command and type tokens.
+ """
+ if type_token:
+ return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+ rtn_strs = []
+ current_str = []
+ if isinstance(Tokens, Tokenization):
+ Tokens = Tokens.tokenization
+ for t in Tokens:
+ if isinstance(t, CommandToken):
+ rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+ current_str = []
+ rtn_strs.append(t.token)
+ elif t in self.command_token_map:
+ rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+ current_str = []
+ rtn_strs.append(t)
+ else:
+ current_str.append(t)
+ if current_str != []:
+ rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+ return ' '.join(rtn_strs)
+
+
+class TextTokenizer(object):
+ """
+ Interface for text tokenizer
+ """
+
+ def __init__(self):
+ if not hasattr(self, 'num_text_tokens'):
+ self.num_text_tokens = 0
+ if not hasattr(self, 'num_tokens'):
+ self.num_tokens = self.num_text_tokens
+
+ def __call__(self, text, process_fn=None):
+ return self.EncodeAsIds(text, process_fn)
+
+ def __len__(self):
+ return self.num_text_tokens
+
+ @property
+ def tokens(self):
+ """list (or iterable) of text tokens for text tokenizer"""
+ raise NotImplementedError('TextTokenizer tokens property not implemented')
+
+ @property
+ def vocab(self):
+ """dictionary mapping tokens to ids"""
+ raise NotImplementedError('TextTokenizer vocab property not implemented')
+
+ @staticmethod
+ def exists(model_path):
+ """check if the filepath for a text tokenizer exists"""
+ raise NotImplementedError('TextTokenizer exists method not implemented')
+
+ def Train(self, corpus):
+ """train a tokenizer on a data corpus and save model for future use"""
+ raise NotImplementedError('TextTokenizer Train not implemented')
+
+ def EncodeAsIds(self, text, process_fn=None):
+ """
+ Preprocess text and encode as ids. Return a tokenization object with
+ original text, processed text, and id tokenization.
+ """
+ raise NotImplementedError('TextTokenizer EncodeAsIds not implemented')
+
+ def EncodeAsTokens(self, text, process_fn=None):
+ """
+ Preprocess text and encode as tokens. Return a tokenization object with
+ original text, processed text, and token tokenization.
+ """
+ raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented')
+
+ def IdToToken(self, Id):
+ """Convert an Id to Token. Reverse lookup of self.vocab"""
+ raise NotImplementedError('TextTokenizer IdToToken not implemented')
+
+ def TokenToId(self, token):
+ """Convert a Token to Id. Lookup of self.vocab"""
+ raise NotImplementedError('TextTokenizer TokenToId not implemented')
+
+ def DecodeIds(self, Ids):
+ """Convert a list or tokenization object of Ids to a text string"""
+ raise NotImplementedError('TextTokenizer DecodeIds not implemented')
+
+ def DecodeTokens(self, Tokens):
+ """Convert a list or tokenization object of tokens to a text string"""
+ raise NotImplementedError('TextTokenizer DecodeTokens not implemented')
+
+
+class CharacterLevelTokenizer(TextTokenizer):
+ """
+ Text tokenizer for ASCII-256 Character Level Tokenization.
+ """
+
+ def __init__(self, **kwargs):
+ self.num_text_tokens = 256
+ super(CharacterLevelTokenizer, self).__init__()
+ self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
+ self._vocab = {t: i for i, t in enumerate(self._tokens)}
+
+ def __len__(self):
+ return 256
+
+ @staticmethod
+ def exists(model_path):
+ return True
+
+ def Train(self, corpus):
+ pass
+
+ @property
+ def tokens(self):
+ return self._tokens
+
+ @property
+ def vocab(self):
+ return self._vocab
+
+ def EncodeAsIds(self, text, process_fn=None):
+ """convert text to ascii 256 Ids"""
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ processed_text = str(processed_text)
+ tokens = [self.TokenToId(c) for c in processed_text]
+ return Tokenization(tokens, processed_text, text)
+
+ def EncodeAsTokens(self, text, process_fn=None):
+ """convert text to ascii 256 characters"""
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ processed_text = str(processed_text)
+ tokens = [c for c in processed_text]
+ return Tokenization(tokens, processed_text, text, asIds=False)
+
+ def IdToToken(self, Id):
+ """ascii index to character"""
+ return chr(Id)
+
+ def TokenToId(self, token):
+ """ascii character to index"""
+ return ord(token)
+
+ def DecodeIds(self, Ids):
+ """converts ascii ids to tokens before joining them into text"""
+ if isinstance(Ids, Tokenization):
+ Ids = Ids.tokenization
+ return ''.join([self.IdToToken(tok) for tok in Ids])
+
+ def DecodeTokens(self, Tokens):
+ """just concatenates ascii tokens into text"""
+ if isinstance(Tokens, Tokenization):
+ Tokens = Tokens.tokenization
+ return ''.join(Tokens)
+
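+# Rough usage sketch for the character tokenizer (ids are just raw ord() values, so no
+# trained model is needed):
+#   tok = CharacterLevelTokenizer()
+#   tok.EncodeAsIds("Hi").tokenization  # -> [72, 105]
+#   tok.DecodeIds([72, 105])            # -> 'Hi'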
+
+MAX_SENTENCEPIECE_SENTENCES = 100000000
+
+
+def get_corpus_freq(dataset, filepath, filetype='tsv'):
+ """
+ Take corpus, split it into sentences, and extract word frequencies.
+ Write frequencies to `filepath` as a tsv. Only write the first
+ MAX_SENTENCEPIECE_SENTENCES most common words to the file.
+ """
+ nltk.download('punkt', download_dir="./nltk")
+ if filetype == 'tsv':
+ delimiter = '\t'
+ else:
+ delimiter = ','
+
+ print("compute corpus frequency\n", flush=True)
+
+ total_sentence_count = 0
+ maxlen = 0
+ freqs = {}
+ for entry in dataset:
+ if isinstance(entry, dict):
+ entry = entry['text']
+ lines = entry.strip().split('\n')
+ for line in lines:
+ sentences = nltk_tokenize.sent_tokenize(line)
+ total_sentence_count += len(sentences)
+ for sentence in sentences:
+ maxlen = max(len(line), maxlen)
+ for word in sentence.split():
+ if word not in freqs:
+ freqs[word] = 0
+ freqs[word] += 1
+
+ print("length of freqs before truncating " + str(len(freqs)), flush=True)
+ print("file path for freq " + str(filepath), flush=True)
+
+ freqs_sorted = {}
+ counter = 0
+ for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
+ if counter >= MAX_SENTENCEPIECE_SENTENCES:
+ break
+ counter += 1
+ freqs_sorted[word] = count
+
+ print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)
+
+ with open(filepath, 'w') as f:
+ writer = csv.writer(f, delimiter=delimiter)
+ for k, v in freqs_sorted.items():
+ writer.writerow([str(k), str(v)])
+
+ return total_sentence_count, maxlen
+
+
+class SentencePieceTokenizer(TextTokenizer):
+ """Trains and uses sentencepiece for text tokenization"""
+
+ def __init__(self, model_type='bpe', vocab_size=None, corpus=None,
+ model_path=None, character_coverage=1.0, **kwargs):
+ self.character_coverage = character_coverage
+ self.model_type = model_type.lower()
+ self.spm_model = model_path
+ self.num_text_tokens = vocab_size
+ make_train = not SentencePieceTokenizer.exists(self.spm_model)
+ if make_train:
+ assert corpus is not None and self.num_text_tokens is not None
+ self.Train(corpus, self.num_text_tokens)
+ self._tokens = []
+ self._vocab = {}
+ self.load_spm_model()
+ super(SentencePieceTokenizer, self).__init__()
+
+ def __len__(self):
+ return self.num_text_tokens
+
+ @property
+ def tokens(self):
+ return self._tokens
+
+ @property
+ def vocab(self):
+ return self._vocab
+
+ @staticmethod
+ def exists(model_path):
+ if model_path is None:
+ return False
+ # check if path exists
+ dne = not os.path.exists(model_path)
+ # check if path.model exists
+ if dne and not model_path.endswith('.model'):
+ dne = not os.path.exists(model_path + '.model')
+ return not dne
+
+ def load_spm_model(self):
+ """load sentencepiece model and parse vocab"""
+ if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
+ self.spm_model = self.spm_model + '.model'
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.Load(self.spm_model)
+ self.vocab_size = self.num_text_tokens = len(self.sp)
+ self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
+ self._vocab = {t: i for i, t in enumerate(self._tokens)}
+
+ def Train(self, corpus, num_text_tokens):
+ """train sentencepiece model on corpus using word frequencies"""
+ self.num_text_tokens = num_text_tokens
+ use_model_path = self.spm_model
+ random_hash = str(random.randint(0, 2147483647))
+ if use_model_path is None:
+ use_model_path = random_hash
+ if use_model_path.endswith('.model'):
+ use_model_path = use_model_path[:use_model_path.rfind('.model')]
+ input_path = use_model_path + '.tsv.' + random_hash
+ line_count, maxlenline = get_corpus_freq(corpus, input_path)
+ line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
+ print('line count used as input_sentence_size ', line_count, flush=True)
+ print('training sentencepiece model', flush=True)
+ train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \
+ + ' --model_type={model_type} --character_coverage={character_coverage} ' \
+ + '--input_sentence_size={input_sentence_size} ' \
+ + '--input_format=tsv'
+ train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens,
+ model_type=self.model_type, character_coverage=self.character_coverage,
+ input_sentence_size=int(line_count))
+ print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
+ spm.SentencePieceTrainer.Train(train_string)
+ os.remove(input_path)
+ self.spm_model = use_model_path + '.model'
+ print('sentencepiece model written to ' + self.spm_model, flush=True)
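+ # For illustration only, with hypothetical inputs (frequencies dumped to 'freqs.tsv.123',
+ # model prefix 'spm', vocab_size=32000, bpe, full character coverage, 1e6 sentences),
+ # the assembled argument string would look like:
+ #   '--input=freqs.tsv.123 --model_prefix=spm --vocab_size=32000 --model_type=bpe '
+ #   '--character_coverage=1.0 --input_sentence_size=1000000 --input_format=tsv'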
+
+ def EncodeAsIds(self, text, process_fn=None):
+ """convert text to sentencepiece Ids"""
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ tokens = self.sp.EncodeAsIds(processed_text)
+ return Tokenization(tokens, processed_text, text)
+
+ def EncodeAsTokens(self, text, process_fn=None):
+ """convert text to sentencepiece tokens"""
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ tokens = self.sp.EncodeAsTokens(processed_text)
+ return Tokenization(tokens, processed_text, text, asIds=False)
+
+ def IdToToken(self, Id):
+ """convert Id to sentencpiece token"""
+ return self.sp.IdToPiece(Id)
+
+ def TokenToId(self, token):
+ """convert sentencpiece token to Id"""
+ return self.sp.PieceToId(token)
+
+ def DecodeIds(self, Ids):
+ """converts ids to a text string"""
+ if isinstance(Ids, Tokenization):
+ Ids = Ids.tokenization
+ return self.sp.DecodeIds(Ids)
+
+ def DecodeTokens(self, Tokens):
+ """converts sentencepiece tokens to a text string"""
+ if isinstance(Tokens, Tokenization):
+ Tokens = Tokens.tokenization
+ return self.sp.DecodeTokens(Tokens)
+
+
+class BertWordPieceTokenizer(Tokenizer):
+ """
+ Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
+ in BERT training. Defaults to the bert-large-uncased tokenizer.
+ """
+
+ def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs):
+ # default to bert-large-uncased tokenizer
+ if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
+ tokenizer_model_type = 'bert-large-uncased'
+ if torch.distributed.get_rank() == 0:
+ print(
+ 'loading BertWordPieceTokenizer (',
+ tokenizer_model_type,
+ ') from cache_dir ',
+ cache_dir)
+ do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
+ self.text_tokenizer = BertTokenizer.from_pretrained(
+ tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
+ if torch.distributed.get_rank() == 0:
+ print('loaded', tokenizer_model_type)
+ # disable max len warnings by increasing max len
+ self.text_tokenizer.max_len = int(1e12)
+
+ # set command tokens from wordpiece tokenizer values
+ self.num_command_tokens = 5
+ self.num_tokens = len(self.text_tokenizer.vocab)
+ self.num_text_tokens = self.num_tokens - 5
+ self.num_type_tokens = 2
+
+ self._command_tokens = [
+ CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
+ CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']),
+ CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']),
+ CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']),
+ CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']),
+ ]
+ self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+ self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+ self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+ # set type tokens
+ self.type_tokens = [
+ TypeToken('str0', '<str0>', 0),
+ TypeToken('str1', '<str1>', 1),
+ ]
+ self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+ self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+ self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+ # parse tokens and vocabs from tokenizer
+
+ self._tokens = list(self.text_tokenizer.vocab.keys())
+ self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
+
+ self._text_tokens = list(self._tokens)
+ self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
+
+ self._command_token_tokens = list(self.command_token_map.keys())
+ self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+ self._token_types = list(self.type_token_map.keys())
+ self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+ def EncodeAsIds(self, text, process_fn=None):
+ """convert text to wordpiece Ids"""
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ tokens = self.text_tokenizer.tokenize(processed_text)
+ Ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
+ return Tokenization(Ids, processed_text, text)
+
+ def EncodeAsTokens(self, text, process_fn=None):
+ """convert wordpiece token to Id"""
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ tokens = self.text_tokenizer.tokenize(processed_text)
+ return Tokenization(tokens, processed_text, text, asIds=False)
+
+ def IdToToken(self, Id, type_token=False):
+ """convert Id to sentencpiece token"""
+ if isinstance(Id, (TypeToken, CommandToken)):
+ return Id.token
+ if type_token:
+ return self.type_id_map[Id].token
+ return self.text_tokenizer.ids_to_tokens[Id]
+
+ def TokenToId(self, token, type_token=False):
+ """convert sentencpiece token to Id"""
+ if isinstance(token, (TypeToken, CommandToken)):
+ return token.Id
+ if type_token:
+ return self.type_token_map[token].Id
+ return self.text_tokenizer.vocab[token]
+
+ def DecodeIds(self, Ids, type_token=False):
+ """converts ids to wordpiece tokens and joins them as a text string"""
+ if type_token:
+ return ' '.join(Id.token if isinstance(Id, TypeToken)
+ else self.type_id_map[Id].token for Id in Ids)
+ if isinstance(Ids, Tokenization):
+ Ids = Ids.tokenization
+ Tokens = []
+ for Id in Ids:
+ Tokens.append(self.text_tokenizer.ids_to_tokens[Id] if Id != -1 else '-1')
+ return ' '.join(Tokens)
+
+ def DecodeTokens(self, Tokens, type_token=False):
+ """converts wordpiece tokens to a text string"""
+ if type_token:
+ return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+ if isinstance(Tokens, Tokenization):
+ Tokens = Tokens.tokenization
+ return ' '.join(Tokens)
+
+
+class GPT2BPETokenizer(Tokenizer):
+ def __init__(self, cache_dir=None, **kwargs):
+ self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
+ cache_dir=cache_dir)
+
+ # disable max len warnings by increasing max len
+ self.text_tokenizer.max_len = int(1e12)
+ self.num_command_tokens = 2
+ self.num_tokens = len(self.text_tokenizer.encoder)
+ self.num_text_tokens = self.num_tokens - 1
+ self.num_type_tokens = 2
+
+ self._command_tokens = [
+ CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
+ CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
+ ]
+ self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+ self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+ self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+ self.type_tokens = [
+ TypeToken('str0', '<str0>', 0),
+ TypeToken('str1', '<str1>', 1),
+ ]
+ self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+ self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+ self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+ self._tokens = list(self.text_tokenizer.encoder.keys())
+ self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
+
+ self._text_tokens = list(self._tokens)
+ self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
+
+ self._command_token_tokens = list(self.command_token_map.keys())
+ self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+ self._token_types = list(self.type_token_map.keys())
+ self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+ def EncodeAsIds(self, text, process_fn=None):
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ Ids = self.text_tokenizer.encode(processed_text)
+ tokenization = Tokenization(Ids, processed_text, text)
+ tokenization.set_command_tokens(self._command_tokens)
+ return tokenization
+
+ def EncodeAsTokens(self, text, process_fn=None):
+ processed_text = text
+ if process_fn is not None:
+ processed_text = process_fn(processed_text)
+ tokens = []
+ for token in re.findall(self.text_tokenizer.pat, processed_text):
+ token = ''.join(self.text_tokenizer.byte_encoder[b] for b in token.encode('utf-8'))
+ tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
+ tokenization = Tokenization(tokens, processed_text, text, asIds=False)
+ tokenization.set_command_tokens(self._command_tokens)
+ return tokenization
+
+ def IdToToken(self, Id, type_token=False):
+ if isinstance(Id, (TypeToken, CommandToken)):
+ return Id.token
+ if type_token:
+ return self.type_id_map[Id].token
+ return self.text_tokenizer.decoder[Id]
+
+ def TokenToId(self, token, type_token=False):
+ if isinstance(token, (TypeToken, CommandToken)):
+ return token.Id
+ if type_token:
+ return self.type_token_map[token].Id
+ return self.text_tokenizer.encoder[token]
+
+ def DecodeIds(self, Ids, type_token=False):
+ if type_token:
+ return ' '.join(Id.token if isinstance(Id, TypeToken)
+ else self.type_id_map[Id].token for Id in Ids)
+ if isinstance(Ids, Tokenization):
+ Ids = Ids.tokenization
+ return self.text_tokenizer.decode(Ids)
+
+ def DecodeTokens(self, Tokens, type_token=False):
+ if type_token:
+ return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+ if isinstance(Tokens, Tokenization):
+ Tokens = Tokens.tokenization
+ return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
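+ # Minimal round-trip sketch (assumes the pretrained 'gpt2' vocab/merges are cached locally;
+ # the cache path below is hypothetical):
+ #   tok = GPT2BPETokenizer(cache_dir='/tmp/gpt2')
+ #   ids = tok.EncodeAsIds("Hello world").tokenization
+ #   tok.DecodeIds(ids)  # -> 'Hello world'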
diff --git a/megatron_lm/megatron/deprecated_data_utils/tokenization_gpt2.py b/megatron_lm/megatron/deprecated_data_utils/tokenization_gpt2.py
new file mode 100644
index 0000000..700fc99
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/tokenization_gpt2.py
@@ -0,0 +1,319 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+from __future__ import (absolute_import, division, print_function,
+ unicode_literals)
+
+import sys
+import json
+import logging
+import os
+import regex as re
+from io import open
+
+try:
+ from functools import lru_cache
+except ImportError:
+ # Just a dummy decorator to get the checks to run on python2
+ # because honestly I don't want to support a byte-level unicode BPE
+ # tokenizer on python 2 right now.
+ def lru_cache():
+ return lambda func: func
+
+from .file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+ 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+}
+PRETRAINED_MERGES_ARCHIVE_MAP = {
+ 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+ 'gpt2': 1024,
+}
+VOCAB_NAME = 'vocab.json'
+MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ _chr = unichr if sys.version_info[0] == 2 else chr
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+ list(range(ord("®"), ord("ÿ") + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [_chr(n) for n in cs]
+ return dict(zip(bs, cs))
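+# For example, printable bytes map to themselves (b'A' -> 'A'), while bytes outside the
+# printable ranges are shifted above 255: a space (byte 32) becomes 'Ġ' (U+0120) and a
+# newline (byte 10) becomes 'Ċ' (U+010A), which is why GPT-2 BPE pieces often start with 'Ġ'.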
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
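+# Quick example of the adjacency pairs fed to the BPE merge loop:
+#   get_pairs(('l', 'o', 'w', 'e', 'r')) -> {('l','o'), ('o','w'), ('w','e'), ('e','r')}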
+
+
+class GPT2Tokenizer(object):
+ """
+ GPT-2 BPE tokenizer. Peculiarities:
+ - Byte-level BPE
+ """
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+ """
+ Instantiate a PreTrainedBertModel from a pre-trained model file.
+ Download and cache the pre-trained model file if needed.
+ """
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+ vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+ merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+ special_tokens_file = None
+ else:
+ vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+ merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+ special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+ if not os.path.exists(special_tokens_file):
+ special_tokens_file = None
+ else:
+ logger.info("loading special tokens file {}".format(special_tokens_file))
+ # redirect to the cache, if necessary
+ try:
+ resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+ resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
+ except EnvironmentError:
+ logger.error(
+ "Model name '{}' was not found in model name list ({}). "
+ "We assumed '{}' was a path or url but couldn't find files {} and {} "
+ "at this path or url.".format(
+ pretrained_model_name_or_path,
+ ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+ pretrained_model_name_or_path,
+ vocab_file, merges_file))
+ return None
+ if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+ logger.info("loading vocabulary file {}".format(vocab_file))
+ logger.info("loading merges file {}".format(merges_file))
+ else:
+ logger.info("loading vocabulary file {} from cache at {}".format(
+ vocab_file, resolved_vocab_file))
+ logger.info("loading merges file {} from cache at {}".format(
+ merges_file, resolved_merges_file))
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+ # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
+ # than the number of positional embeddings
+ max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+ kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+ # Instantiate tokenizer.
+ if special_tokens_file and 'special_tokens' not in kwargs:
+ special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+ else:
+ special_tokens = kwargs.pop('special_tokens', [])
+ tokenizer = cls(
+ resolved_vocab_file,
+ resolved_merges_file,
+ special_tokens=special_tokens,
+ *inputs,
+ **kwargs)
+ return tokenizer
+
+ def __init__(self, vocab_file, merges_file, errors='replace',
+ special_tokens=None, max_len=None):
+ self.max_len = max_len if max_len is not None else int(1e12)
+ self.encoder = json.load(open(vocab_file))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+
+ # Should have added re.IGNORECASE so BPE merges can happen for
+ # capitalized versions of contractions
+ self.pat = re.compile(
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ self.special_tokens = {}
+ self.special_tokens_decoder = {}
+ self.set_special_tokens(special_tokens)
+
+ def __len__(self):
+ return len(self.encoder) + len(self.special_tokens)
+
+ def set_special_tokens(self, special_tokens):
+ """ Add a list of additional tokens to the encoder.
+ The additional tokens are indexed starting from the last index of the
+ current vocabulary in the order of the `special_tokens` list.
+ """
+ if not special_tokens:
+ self.special_tokens = {}
+ self.special_tokens_decoder = {}
+ return
+ self.special_tokens = dict((tok, len(self.encoder) + i)
+ for i, tok in enumerate(special_tokens))
+ self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
+ logger.info("Special tokens {}".format(self.special_tokens))
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except BaseException:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def tokenize(self, text):
+ """ Tokenize a string. """
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ if sys.version_info[0] == 2:
+ token = ''.join(self.byte_encoder[ord(b)] for b in token)
+ else:
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
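+ # Sketch (assuming the pretrained 'gpt2' merges have been downloaded): tokenize("Hello world")
+ # first regex-splits the text into ['Hello', ' world'], byte-encodes the leading space to 'Ġ',
+ # and typically returns the BPE pieces ['Hello', 'Ġworld'].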
+
+ def convert_tokens_to_ids(self, tokens):
+ """ Converts a sequence of tokens into ids using the vocab. """
+ ids = []
+ if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
+ if tokens in self.special_tokens:
+ return self.special_tokens[tokens]
+ else:
+ return self.encoder.get(tokens, 0)
+ for token in tokens:
+ if token in self.special_tokens:
+ ids.append(self.special_tokens[token])
+ else:
+ ids.append(self.encoder.get(token, 0))
+ if len(ids) > self.max_len:
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum "
+ " sequence length for this OpenAI GPT model ({} > {}). Running this"
+ " sequence through the model will result in indexing errors".format(
+ len(ids), self.max_len)
+ )
+ return ids
+
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+ """Converts a sequence of ids in BPE tokens using the vocab."""
+ tokens = []
+ for i in ids:
+ if i in self.special_tokens_decoder:
+ if not skip_special_tokens:
+ tokens.append(self.special_tokens_decoder[i])
+ else:
+ tokens.append(self.decoder[i])
+ return tokens
+
+ def encode(self, text):
+ return self.convert_tokens_to_ids(self.tokenize(text))
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+ return text
+
+ def save_vocabulary(self, vocab_path):
+ """Save the tokenizer vocabulary and merge files to a directory."""
+ if not os.path.isdir(vocab_path):
+ logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+ return
+ vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+ merge_file = os.path.join(vocab_path, MERGES_NAME)
+ special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+ with open(vocab_file, 'w', encoding='utf-8') as f:
+ f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write(u'#version: 0.2\n')
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(merge_file))
+ index = token_index
+ writer.write(' '.join(bpe_tokens) + u'\n')
+ index += 1
+
+ index = len(self.encoder)
+ with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+ for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
+ index = token_index
+ writer.write(token + u'\n')
+ index += 1
+
+ return vocab_file, merge_file, special_tokens_file
diff --git a/megatron_lm/megatron/deprecated_data_utils/wordpiece.py b/megatron_lm/megatron/deprecated_data_utils/wordpiece.py
new file mode 100755
index 0000000..cc20443
--- /dev/null
+++ b/megatron_lm/megatron/deprecated_data_utils/wordpiece.py
@@ -0,0 +1,391 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py"""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import collections
+import logging
+import os
+import unicodedata
+from io import open
+
+from .file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+ 'bert-base-uncased': 512,
+ 'bert-large-uncased': 512,
+ 'bert-base-cased': 512,
+ 'bert-large-cased': 512,
+ 'bert-base-multilingual-uncased': 512,
+ 'bert-base-multilingual-cased': 512,
+ 'bert-base-chinese': 512,
+}
+VOCAB_NAME = 'vocab.txt'
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ index = 0
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ while True:
+ token = reader.readline()
+ if not token:
+ break
+ token = token.strip()
+ vocab[token] = index
+ index += 1
+ return vocab
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BertTokenizer(object):
+ """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
+
+ def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+ """Constructs a BertTokenizer.
+
+ Args:
+ vocab_file: Path to a one-wordpiece-per-line vocabulary file
+ do_lower_case: Whether to lower case the input
+ Only has an effect when do_wordpiece_only=False
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+ max_len: An artificial maximum length to truncate tokenized sequences to;
+ Effective maximum length is always the minimum of this
+ value (if specified) and the underlying BERT model's
+ sequence length.
+ never_split: List of tokens which will never be split during tokenization.
+ Only has an effect when do_wordpiece_only=False
+ """
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict(
+ [(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+ never_split=never_split)
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+ self.max_len = max_len if max_len is not None else int(1e12)
+
+ def tokenize(self, text):
+ if self.do_basic_tokenize:
+ split_tokens = []
+ for token in self.basic_tokenizer.tokenize(text):
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
+ split_tokens.append(sub_token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ """Converts a sequence of tokens into ids using the vocab."""
+ ids = []
+ for token in tokens:
+ ids.append(self.vocab[token])
+ if len(ids) > self.max_len:
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum "
+ " sequence length for this BERT model ({} > {}). Running this"
+ " sequence through BERT will result in indexing errors".format(
+ len(ids), self.max_len)
+ )
+ return ids
+
+ def convert_ids_to_tokens(self, ids):
+ """Converts a sequence of ids in wordpiece tokens using the vocab."""
+ tokens = []
+ for i in ids:
+ tokens.append(self.ids_to_tokens[i])
+ return tokens
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+ """
+ Instantiate a PreTrainedBertModel from a pre-trained model file.
+ Download and cache the pre-trained model file if needed.
+ """
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+ vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+ else:
+ vocab_file = pretrained_model_name_or_path
+ if os.path.isdir(vocab_file):
+ vocab_file = os.path.join(vocab_file, VOCAB_NAME)
+ # redirect to the cache, if necessary
+ try:
+ resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+ except EnvironmentError:
+ logger.error(
+ "Model name '{}' was not found in model name list ({}). "
+ "We assumed '{}' was a path or url but couldn't find any file "
+ "associated to this path or url.".format(
+ pretrained_model_name_or_path,
+ ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+ vocab_file))
+ return None
+ if resolved_vocab_file == vocab_file:
+ logger.info("loading vocabulary file {}".format(vocab_file))
+ else:
+ logger.info("loading vocabulary file {} from cache at {}".format(
+ vocab_file, resolved_vocab_file))
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+ # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
+ # than the number of positional embeddings
+ max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+ kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+ # Instantiate tokenizer.
+ tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
+ return tokenizer
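+ # Usage sketch (the vocab file is downloaded and cached on the first call):
+ #   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ #   tokens = tokenizer.tokenize("Hello, world!")   # basic split + wordpiece, lower-cased
+ #   ids = tokenizer.convert_tokens_to_ids(tokens)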
+
+
+class BasicTokenizer(object):
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+ def __init__(self,
+ do_lower_case=True,
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+ """Constructs a BasicTokenizer.
+
+ Args:
+ do_lower_case: Whether to lower case the input.
+ """
+ self.do_lower_case = do_lower_case
+ self.never_split = never_split
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text."""
+ text = self._clean_text(text)
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ text = self._tokenize_chinese_chars(text)
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if self.do_lower_case and token not in self.never_split:
+ token = token.lower()
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text):
+ """Splits punctuation on a piece of text."""
+ if text in self.never_split:
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xfffd or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+class WordpieceTokenizer(object):
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text into its word pieces.
+
+ This uses a greedy longest-match-first algorithm to perform tokenization
+ using the given vocabulary.
+
+ For example:
+ input = "unaffable"
+ output = ["un", "##aff", "##able"]
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through `BasicTokenizer`.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
+def _is_whitespace(char):
+ """Checks whether `chars` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def _is_control(char):
+ """Checks whether `chars` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat.startswith("C"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+ """Checks whether `chars` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
diff --git a/megatron_lm/megatron/fp16/__init__.py b/megatron_lm/megatron/fp16/__init__.py
new file mode 100644
index 0000000..56ee11f
--- /dev/null
+++ b/megatron_lm/megatron/fp16/__init__.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .fp16util import (
+ BN_convert_float,
+ network_to_half,
+ prep_param_lists,
+ model_grads_to_master_grads,
+ master_params_to_model_params,
+ tofp16,
+ to_python_float,
+ clip_grad_norm,
+ convert_module,
+ convert_network,
+ FP16Model,
+)
+
+from .fp16 import *
+from .loss_scaler import *
diff --git a/megatron_lm/megatron/fp16/fp16.py b/megatron_lm/megatron/fp16/fp16.py
new file mode 100755
index 0000000..bdea6ad
--- /dev/null
+++ b/megatron_lm/megatron/fp16/fp16.py
@@ -0,0 +1,651 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Stable version of apex FP16 Optimizer"""
+import torch
+from torch import nn
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from .loss_scaler import DynamicLossScaler, LossScaler
+from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
+
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
+from megatron.module import MegatronModule
+
+FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
+HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+
+
+def conversion_helper(val, conversion):
+ """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
+ if not isinstance(val, (tuple, list)):
+ return conversion(val)
+ rtn = [conversion_helper(v, conversion) for v in val]
+ if isinstance(val, tuple):
+ rtn = tuple(rtn)
+ return rtn
+
+
+def fp32_to_fp16(val):
+ """Convert fp32 `val` to fp16"""
+ def half_conversion(val):
+ val_typecheck = val
+ if isinstance(val_typecheck, (Parameter, Variable)):
+ val_typecheck = val.data
+ if isinstance(val_typecheck, FLOAT_TYPES):
+ val = val.half()
+ return val
+ return conversion_helper(val, half_conversion)
+
+
+def fp16_to_fp32(val):
+ """Convert fp16 `val` to fp32"""
+ def float_conversion(val):
+ val_typecheck = val
+ if isinstance(val_typecheck, (Parameter, Variable)):
+ val_typecheck = val.data
+ if isinstance(val_typecheck, HALF_TYPES):
+ val = val.float()
+ return val
+ return conversion_helper(val, float_conversion)
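+# These helpers preserve arbitrary tuple/list nesting: for a hypothetical model output such as
+# (loss, [logits, hidden]), fp32_to_fp16 halves every FP32 tensor leaf and returns the same
+# (tuple, [list]) structure, and fp16_to_fp32 does the inverse on the way out of FP16_Module.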
+
+
+class FP16_Module(MegatronModule):
+ def __init__(self, module):
+ super(FP16_Module, self).__init__()
+ self.add_module('module', module.half())
+
+ def forward(self, *inputs, **kwargs):
+ return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
+
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
+ return self.module.state_dict(destination, prefix, keep_vars)
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ return self.module.state_dict_for_save_checkpoint(destination, prefix,
+ keep_vars)
+
+ def load_state_dict(self, state_dict, strict=True):
+ self.module.load_state_dict(state_dict, strict=strict)
+
+# TODO: Update overflow check + downscale to use Carl's fused kernel.
+
+
+class FP16_Optimizer(object):
+ """
+ :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
+ and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
+ For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
+ and changing the call to ``backward``.
+
+ Example::
+
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+ # Name the FP16_Optimizer instance to replace the existing optimizer
+ # (recommended but not required):
+ optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+ ...
+ # loss.backward() becomes:
+ optimizer.backward(loss)
+ ...
+
+ Example with dynamic loss scaling::
+
+ ...
+ optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+ # optional arg to control dynamic loss scaling behavior
+ # dynamic_loss_args={'scale_window' : 500})
+ # Usually, dynamic_loss_args is not necessary.
+
+ Args:
+ init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
+ static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
+ dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
+ dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
+ verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.
+
+ ``init_optimizer`` is expected to have been constructed in the ordinary way.
+ It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
+ named to replace ``init_optimizer``, for two reasons:
+ First, it means that references to the same name
+ later in the file will not have to change.
+ Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
+ modify ``init_optimizer``. If you do choose a unique name for the new
+ :class:`FP16_Optimizer` instance, you should only work with this new instance,
+ because the preexisting optimizer might no longer behave as expected.
+
+ ``init_optimizer`` may be any Pytorch optimizer.
+ It may contain a mixture of fp16 and fp32 parameters organized into any number of
+ ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
+ ingest these ``param_groups`` and remember them.
+
+ Calls to ::
+
+ loss.backward()
+
+ must be replaced with ::
+
+ optimizer.backward(loss)
+
+ because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
+ loss scaling and copies to master gradients.
+
+ .. note::
+ Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
+ are downscaled before being applied. This means that adjusting the loss scale, or using
+ dynamic loss scaling, should not require retuning the learning rate or any other
+ hyperparameters.
+
+
+ **Advanced options**
+
+ **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
+ See docstring for :attr:`step`.
+
+ **Gradient clipping**: Use :attr:`clip_master_grads`.
+
+ **Multiple losses**: If your model accumulates gradients from multiple losses,
+ this can be made more efficient by supplying ``update_master_grads=False``
+ to :attr:`backward`. See docstring for :attr:`backward`.
+
+ **Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::
+
+ print(optimizer.loss_scale)
+ optimizer.loss_scale = new_loss_scale
+
+ For static loss scaling, manually adjusting the loss scale over time is a reasonable
+ thing to do. During later epochs, gradients may become smaller, and a
+ higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
+ scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
+ the loss scale is not recommended.
+
+ **Multi-GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
+ Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
+ should still work as intended.
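+
+ As an orientation aid, a minimal end-to-end sketch is shown below (``model``, ``loss_fn``,
+ ``dataset``, and the hyperparameters are illustrative placeholders, not part of this module)::
+
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+ optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+ for input, target in dataset:
+ optimizer.zero_grad()
+ output = model(input)
+ loss = loss_fn(output, target)
+ optimizer.backward(loss) # replaces loss.backward()
+ optimizer.step()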
+ """
+
+ def __init__(self,
+ init_optimizer,
+ static_loss_scale=1.0,
+ dynamic_loss_scale=False,
+ dynamic_loss_args=None,
+ verbose=False):
+ if not torch.cuda.is_available():
+ raise SystemError("Cannot use fp16 without CUDA.")
+
+ self.verbose = verbose
+
+ self.optimizer = init_optimizer
+ # init_state_dict sets up an alternative way to cast per-param state tensors.
+ # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
+ # init_state_dict = init_optimizer.state_dict()
+
+ self.fp16_groups = []
+ self.fp32_from_fp16_groups = []
+ self.fp32_from_fp32_groups = []
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
+ fp16_params_this_group = []
+ fp32_params_this_group = []
+ fp32_from_fp16_params_this_group = []
+ for i, param in enumerate(param_group['params']):
+ if param.requires_grad:
+ if param.type() == 'torch.cuda.HalfTensor':
+ self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
+ .format(param.size()))
+ fp16_params_this_group.append(param)
+ master_param = param.detach().clone().float()
+ master_param.requires_grad = True
+ # Copy the model parallel flag.
+ master_param.model_parallel = param.model_parallel
+ param_group['params'][i] = master_param
+ fp32_from_fp16_params_this_group.append(master_param)
+ # Reset existing state dict key to the new master param.
+ # We still need to recast per-param state tensors, if any, to FP32.
+ if param in self.optimizer.state:
+ self.optimizer.state[master_param] = self.optimizer.state.pop(param)
+ elif param.type() == 'torch.cuda.FloatTensor':
+ self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
+ .format(param.size()))
+ fp32_params_this_group.append(param)
+ param_group['params'][i] = param
+ else:
+ raise TypeError("Wrapped parameters must be either "
+ "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+ "Received {}".format(param.type()))
+
+ self.fp16_groups.append(fp16_params_this_group)
+ self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+ self.fp32_from_fp32_groups.append(fp32_params_this_group)
+
+ # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
+ self.optimizer.load_state_dict(self.optimizer.state_dict())
+ # alternative way to cast per-param state tensors:
+ # self.optimizer.load_state_dict(init_state_dict)
+
+ if dynamic_loss_scale:
+ self.dynamic_loss_scale = True
+ if dynamic_loss_args is not None:
+ self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
+ else:
+ self.loss_scaler = DynamicLossScaler()
+ else:
+ self.dynamic_loss_scale = False
+ self.loss_scaler = LossScaler(static_loss_scale)
+
+ self.overflow = False
+ self.first_closure_call_this_step = True
+
+ self.clip_grad_norm = clip_grad_norm
+
+ def maybe_print(self, msg):
+ if self.verbose:
+ print(msg)
+
+ def __getstate__(self):
+ raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
+
+ def __setstate__(self, state):
+ raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
+
+ def zero_grad(self, set_grads_to_None=False):
+ """
+ Zero fp32 and fp16 parameter grads. If ``set_grads_to_None`` is True, the ``.grad``
+ attributes are set to None instead of being detached and zeroed in place.
+ """
+ # In principle, only the .grad attributes of the model params need to be zeroed,
+ # because gradients are copied into the FP32 master params. However, we zero
+ # all gradients owned by the optimizer, just to be safe:
+ for group in self.optimizer.param_groups:
+ for p in group['params']:
+ if set_grads_to_None:
+ p.grad = None
+ else:
+ if p.grad is not None:
+ p.grad.detach_()
+ p.grad.zero_()
+
+ # Zero fp16 gradients owned by the model:
+ for fp16_group in self.fp16_groups:
+ for param in fp16_group:
+ if set_grads_to_None:
+ param.grad = None
+ else:
+ if param.grad is not None:
+ param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
+ param.grad.zero_()
+
+ def _check_overflow(self):
+ params = []
+ for group in self.fp16_groups:
+ for param in group:
+ params.append(param)
+ for group in self.fp32_from_fp32_groups:
+ for param in group:
+ params.append(param)
+ self.overflow = self.loss_scaler.has_overflow(params)
+
+ def _update_scale(self, has_overflow=False):
+ self.loss_scaler.update_scale(has_overflow)
+
+ def _master_params_to_model_params(self):
+ for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
+ master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
+ def _model_params_to_master_params(self):
+ for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
+ master_params_to_model_params(fp32_from_fp16_group, fp16_group)
+
+ # To consider: Integrate distributed with this wrapper by registering a hook on each variable
+ # that does the overflow check, gradient copy + downscale, and fp32
+ # allreduce in a different stream.
+ def _model_grads_to_master_grads(self):
+ for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
+ model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
+
+ def _downscale_master(self):
+ if self.loss_scale != 1.0:
+ for group in self.optimizer.param_groups:
+ grads = [p.grad for p in group['params'] if p.grad is not None]
+ _overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(amp_C.multi_tensor_scale,
+ _overflow_buf,
+ [grads, grads],
+ 1./self.loss_scale)
+
+ def clip_master_grads(self, max_norm, norm_type=2):
+ """
+ Clips fp32 master gradients via ``mpu.clip_grad_norm``.
+
+ Args:
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+
+ Returns:
+ Total norm of the current fp32 gradients (viewed as a single vector).
+
+ .. warning::
+ Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
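+
+ Example (a minimal sketch; ``max_grad_norm`` is an illustrative value)::
+
+ optimizer.backward(loss)
+ optimizer.clip_master_grads(max_grad_norm)
+ optimizer.step()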
+ """
+ if not self.overflow:
+ fp32_params = []
+ for param_group in self.optimizer.param_groups:
+ for param in param_group['params']:
+ fp32_params.append(param)
+ return self.clip_grad_norm(fp32_params, max_norm, norm_type)
+ else:
+ return -1
+
+ def state_dict(self):
+ """
+ Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
+ This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
+ of the contained Pytorch optimizer.
+ Example::
+
+ checkpoint = {}
+ checkpoint['model'] = model.state_dict()
+ checkpoint['optimizer'] = optimizer.state_dict()
+ torch.save(checkpoint, "saved.pth")
+ """
+ state_dict = {}
+ state_dict['loss_scaler'] = self.loss_scaler
+ state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
+ state_dict['overflow'] = self.overflow
+ state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
+ state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
+ state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ """
+ Loads a state_dict created by an earlier call to state_dict().
+ If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
+ whose parameters in turn came from ``model``, it is expected that the user
+ will call ``model.load_state_dict()`` before
+ ``fp16_optimizer_instance.load_state_dict()`` is called.
+
+ Example::
+
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+ optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+ ...
+ checkpoint = torch.load("saved.pth")
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ """
+ # I think it should actually be ok to reload the optimizer before the model.
+ self.loss_scaler = state_dict['loss_scaler']
+ self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
+ self.overflow = state_dict['overflow']
+ self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
+ self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+ # At this point, the optimizer's references to the model's fp32 parameters are up to date.
+ # The optimizer's hyperparameters and internal buffers are also up to date.
+ # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
+ # out of date. There are two options.
+ # 1: Refresh the master params from the model's fp16 params.
+ # This requires less storage but incurs precision loss.
+ # 2: Save and restore the fp32 master copies separately.
+ # We choose option 2.
+ #
+ # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
+ # of their associated parameters, because it's possible those buffers might not exist yet in
+ # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
+ # constructed in the same way as the one whose state_dict we are loading, the same master params
+ # are guaranteed to exist, so we can just copy_() from the saved master params.
+ for current_group, saved_group in zip(
+ self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
+ for current, saved in zip(current_group, saved_group):
+ current.data.copy_(saved.data)
+
+ def step(self, closure=None): # could add clip option.
+ """
+ If no closure is supplied, :attr:`step` should be called after
+ ``fp16_optimizer_obj.backward(loss)``.
+ :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
+ :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
+ originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
+ another forward pass using their model.
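+
+ Example without a closure (a minimal sketch of one training iteration; ``model``, ``loss_fn``,
+ ``input``, and ``target`` are illustrative)::
+
+ optimizer.zero_grad()
+ output = model(input)
+ loss = loss_fn(output, target)
+ optimizer.backward(loss)
+ optimizer.step()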
+
+ If a closure is supplied, :attr:`step` may be called without a prior call to
+ :attr:`backward(loss)`.
+ This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
+ However, the user should take care that any ``loss.backward()`` call within the closure
+ has been replaced by ``fp16_optimizer_obj.backward(loss)``.
+
+ Args:
+ closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.
+
+ Example with closure::
+
+ # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
+ # existing pytorch optimizer.
+ for input, target in dataset:
+ def closure():
+ optimizer.zero_grad()
+ output = model(input)
+ loss = loss_fn(output, target)
+ # loss.backward() becomes:
+ optimizer.backward(loss)
+ return loss
+ optimizer.step(closure)
+
+ .. warning::
+ Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.
+
+ .. _`ordinary Pytorch optimizer use`:
+ http://pytorch.org/docs/master/optim.html#optimizer-step-closure
+ """
+
+ scale = self.loss_scaler.loss_scale
+ self._update_scale(self.overflow)
+
+ if self.overflow:
+ self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}"
+ .format(scale, self.loss_scale))
+ return
+
+ if closure is not None:
+ retval = self._step_with_closure(closure)
+ else:
+ retval = self.optimizer.step()
+
+ self._master_params_to_model_params()
+
+ return retval
+
+ def _step_with_closure(self, closure):
+ def wrapped_closure():
+ # helpful for debugging
+ # print("Calling wrapped_closure, first_closure_call_this_step = {}"
+ # .format(self.first_closure_call_this_step))
+ if self.first_closure_call_this_step:
+ # We expect that the fp16 params are initially fresh on entering self.step(),
+ # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
+ # is called within self.optimizer.step().
+ self.first_closure_call_this_step = False
+ else:
+ # If self.optimizer.step() internally calls wrapped_closure more than once,
+ # it may update the fp32 params after each call. However, self.optimizer
+ # doesn't know about the fp16 params at all. If the fp32 params get updated,
+ # we can't rely on self.optimizer to refresh the fp16 params. We need
+ # to handle that manually:
+ self._master_params_to_model_params()
+ # Our API expects the user to give us ownership of the backward() call by
+ # replacing all calls to loss.backward() with optimizer.backward(loss).
+ # This requirement holds whether or not the call to backward() is made within a closure.
+ # If the user is properly calling optimizer.backward(loss) within "closure,"
+ # calling closure() here will give the fp32 master params fresh gradients
+ # for the optimizer to play with, so all wrapped_closure needs to do is call
+ # closure() and return the loss.
+ temp_loss = closure()
+ while self.overflow:
+ scale = self.loss_scaler.loss_scale
+ self._update_scale(self.overflow)
+ self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
+ "reducing to {}".format(scale, self.loss_scale))
+ temp_loss = closure()
+ return temp_loss
+
+ retval = self.optimizer.step(wrapped_closure)
+
+ self.first_closure_call_this_step = True
+
+ return retval
+
+ def backward(self, loss, update_master_grads=True, retain_graph=False):
+ """
+ :attr:`backward` performs the following conceptual steps:
+
+ 1. fp32_loss = loss.float() (see first Note below)
+ 2. scaled_loss = fp32_loss*loss_scale
+ 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
+ 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
+ 5. Finally, master grads are divided by loss_scale.
+
+ In this way, after :attr:`backward`, the master params have fresh gradients,
+ and :attr:`step` may be called.
+
+ .. note::
+ :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
+ This provides some additional safety against overflow if the user has supplied an
+ fp16 loss value.
+ However, for maximum overflow safety, the user should
+ compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
+ :attr:`backward`.
+
+ .. warning::
+ The gradients found in a model's leaves after the call to
+ :attr:`backward` should not be regarded as valid in general,
+ because it's possible
+ they have been scaled (and in the case of dynamic loss scaling,
+ the scale factor may change over time).
+ If the user wants to inspect gradients after a call to :attr:`backward`,
+ only the master gradients should be regarded as valid. These can be retrieved via
+ :attr:`inspect_master_grad_data()`.
+
+ Args:
+ loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
+ update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
+ retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
+
+ Example::
+
+ # Ordinary operation:
+ optimizer.backward(loss)
+
+ # Naive operation with multiple losses (technically valid, but less efficient):
+ # fp32 grads will be correct after the second call, but
+ # the first call incurs an unnecessary fp16->fp32 grad copy.
+ optimizer.backward(loss1)
+ optimizer.backward(loss2)
+
+ # More efficient way to handle multiple losses:
+ # The fp16->fp32 grad copy is delayed until fp16 grads from all
+ # losses have been accumulated.
+ optimizer.backward(loss1, update_master_grads=False)
+ optimizer.backward(loss2, update_master_grads=False)
+ optimizer.update_master_grads()
+ """
+ # To consider: try multiple backward passes using retain_graph=True to find
+ # a loss scale that works. After you find a loss scale that works, do a final dummy
+ # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
+ # discarding the iteration, but probably wouldn't improve overall efficiency.
+ self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
+ if update_master_grads:
+ self.update_master_grads()
+
+ def update_master_grads(self):
+ """
+ Copy the ``.grad`` attribute from stored references to fp16 parameters to
+ the ``.grad`` attribute of the fp32 master parameters that are directly
+ updated by the optimizer. :attr:`update_master_grads` only needs to be called if
+ ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
+ """
+ if self.dynamic_loss_scale:
+ self._check_overflow()
+ if self.overflow:
+ return
+ self._model_grads_to_master_grads()
+ self._downscale_master()
+
+ def inspect_master_grad_data(self):
+ """
+ When running with :class:`FP16_Optimizer`,
+ ``.grad`` attributes of a model's fp16 leaves should not be
+ regarded as truthful, because they might be scaled.
+ After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
+ the fp32 master params' ``.grad``
+ attributes will contain valid gradients properly divided by the loss scale. However,
+ because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
+ nonintuitive. :attr:`inspect_master_grad_data`
+ allows those gradients to be viewed with shapes corresponding to their associated model leaves.
+
+ Returns:
+ List of lists (one list for each parameter group). The list for each parameter group
+ is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
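+
+ Example (a minimal sketch)::
+
+ optimizer.backward(loss)
+ master_grads = optimizer.inspect_master_grad_data()
+ print(master_grads[0][0]) # .grad.data of the first master param in the first group (may be None)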
+ """
+ if self.overflow:
+ print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
+ "Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
+ return None
+ else:
+ # The optimizer owns only references to master params.
+ master_grads_data = []
+ for param_group in self.optimizer.param_groups:
+ master_grads_this_group = []
+ for param in param_group['params']:
+ if param.grad is not None:
+ master_grads_this_group.append(param.grad.data)
+ else:
+ master_grads_this_group.append(None)
+ master_grads_data.append(master_grads_this_group)
+ return master_grads_data
+
+ # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
+
+ def _get_loss_scale(self):
+ return self.loss_scaler.loss_scale
+
+ def _set_loss_scale(self, value):
+ self.loss_scaler.cur_scale = value
+
+ loss_scale = property(_get_loss_scale, _set_loss_scale)
+
+ # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
+ def _get_state(self):
+ return self.optimizer.state
+
+ def _set_state(self, value):
+ self.optimizer.state = value
+
+ state = property(_get_state, _set_state)
+
+ # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
+ # (for example, to adjust the learning rate)
+ def _get_param_groups(self):
+ return self.optimizer.param_groups
+
+ def _set_param_groups(self, value):
+ self.optimizer.param_groups = value
+
+ param_groups = property(_get_param_groups, _set_param_groups)
diff --git a/megatron_lm/megatron/fp16/fp16util.py b/megatron_lm/megatron/fp16/fp16util.py
new file mode 100644
index 0000000..0266ede
--- /dev/null
+++ b/megatron_lm/megatron/fp16/fp16util.py
@@ -0,0 +1,216 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
+from megatron import mpu
+
+
+class tofp16(nn.Module):
+ """
+ Utility module that implements::
+
+ def forward(self, input):
+ return input.half()
+ """
+
+ def __init__(self):
+ super(tofp16, self).__init__()
+
+ def forward(self, input):
+ return input.half()
+
+
+def BN_convert_float(module):
+ """
+ Utility function for network_to_half().
+
+ Retained for legacy purposes.
+ """
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
+ module.float()
+ for child in module.children():
+ BN_convert_float(child)
+ return module
+
+
+def network_to_half(network):
+ """
+ Convert model to half precision in a batchnorm-safe way.
+
+ Retained for legacy purposes. It is recommended to use FP16Model.
+ """
+ return nn.Sequential(tofp16(), BN_convert_float(network.half()))
+
+
+def convert_module(module, dtype):
+ """
+ Converts a module's immediate parameters and buffers to dtype.
+ """
+ for param in module.parameters(recurse=False):
+ if param is not None:
+ if param.data.dtype.is_floating_point:
+ param.data = param.data.to(dtype=dtype)
+ if param._grad is not None and param._grad.data.dtype.is_floating_point:
+ param._grad.data = param._grad.data.to(dtype=dtype)
+
+ for buf in module.buffers(recurse=False):
+ if buf is not None and buf.data.dtype.is_floating_point:
+ buf.data = buf.data.to(dtype=dtype)
+
+
+def convert_network(network, dtype):
+ """
+ Converts a network's parameters and buffers to dtype.
+ """
+ for module in network.modules():
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
+ continue
+ convert_module(module, dtype)
+ return network
+
+
+class FP16Model(nn.Module):
+ """
+ Convert model to half precision in a batchnorm-safe way.
+ """
+
+ def __init__(self, network):
+ super(FP16Model, self).__init__()
+ self.network = convert_network(network, dtype=torch.half)
+
+ def forward(self, *inputs):
+ inputs = tuple(t.half() for t in inputs)
+ return self.network(*inputs)
+
+
+def backwards_debug_hook(grad):
+ raise RuntimeError("master_params recieved a gradient in the backward pass!")
+
+
+def prep_param_lists(model, flat_master=False):
+ """
+ Creates a list of FP32 master parameters for a given model, as in
+ `Training Neural Networks with Mixed Precision: Real Examples`_.
+
+ Args:
+ model (torch.nn.Module): Existing Pytorch model
+ flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
+ Returns:
+ A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element.
+
+ Example::
+
+ model_params, master_params = prep_param_lists(model)
+
+ .. warning::
+ Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
+
+ .. _`Training Neural Networks with Mixed Precision: Real Examples`:
+ http://on-demand.gputechconf.com/gtc/2018/video/S81012/
+ """
+ model_params = [param for param in model.parameters() if param.requires_grad]
+
+ if flat_master:
+ # Give the user some more useful error messages
+ try:
+ # flatten_dense_tensors returns a contiguous flat array.
+ # http://pytorch.org/docs/master/_modules/torch/_utils.html
+ master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
+ except BaseException:
+ print("Error in prep_param_lists: model may contain a mixture of parameters "
+ "of different types. Use flat_master=False, or use F16_Optimizer.")
+ raise
+ master_params = torch.nn.Parameter(master_params)
+ master_params.requires_grad = True
+ # master_params.register_hook(backwards_debug_hook)
+ if master_params.grad is None:
+ master_params.grad = master_params.new(*master_params.size())
+ return model_params, [master_params]
+ else:
+ master_params = [param.clone().float().detach() for param in model_params]
+ for param in master_params:
+ param.requires_grad = True
+ return model_params, master_params
+
+
+def model_grads_to_master_grads(model_params, master_params, flat_master=False):
+ """
+ Copy model gradients to master gradients.
+
+ Args:
+ model_params: List of model parameters created by :func:`prep_param_lists`.
+ master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
+ """
+ if flat_master:
+ # The flattening may incur one more deep copy than is necessary.
+ master_params[0].grad.data.copy_(
+ _flatten_dense_tensors([p.grad.data for p in model_params]))
+ else:
+ for model, master in zip(model_params, master_params):
+ if model.grad is not None:
+ if master.grad is None:
+ master.grad = Variable(master.data.new(*master.data.size()))
+ else:
+ master.grad = None
+ model_grads = [p.grad for p in model_params if p.grad is not None]
+ master_grads = [p.grad for p in master_params if p.grad is not None]
+ _overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(amp_C.multi_tensor_scale,
+ _overflow_buf,
+ [model_grads, master_grads],
+ 1.0)
+
+
+def master_params_to_model_params(model_params, master_params, flat_master=False):
+ """
+ Copy master parameters to model parameters.
+
+ Args:
+ model_params: List of model parameters created by :func:`prep_param_lists`.
+ master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
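+
+ Example (a minimal sketch of the manual master-weight workflow; the optimizer, loss, and data names are illustrative)::
+
+ model_params, master_params = prep_param_lists(model)
+ optimizer = torch.optim.SGD(master_params, lr=1e-3)
+ loss = loss_fn(model(input), target)
+ model.zero_grad()
+ loss.backward()
+ model_grads_to_master_grads(model_params, master_params)
+ optimizer.step()
+ master_params_to_model_params(model_params, master_params)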
+ """
+ if flat_master:
+ for model, master in zip(model_params,
+ _unflatten_dense_tensors(master_params[0].data, model_params)):
+ model.data.copy_(master)
+ else:
+ for model, master in zip(model_params, master_params):
+ model.data.copy_(master.data)
+
+# Backward compatibility fixes
+
+
+def to_python_float(t):
+ if hasattr(t, 'item'):
+ return t.item()
+ else:
+ return t[0]
+
+
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+clip_grad_norm = mpu.clip_grad_norm
+# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
+# clip_grad_norm = torch.nn.utils.clip_grad_norm
+# else:
+# clip_grad_norm = torch.nn.utils.clip_grad_norm_
diff --git a/megatron_lm/megatron/fp16/loss_scaler.py b/megatron_lm/megatron/fp16/loss_scaler.py
new file mode 100755
index 0000000..126b786
--- /dev/null
+++ b/megatron_lm/megatron/fp16/loss_scaler.py
@@ -0,0 +1,256 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
+from megatron import mpu
+
+# item() is a recent addition, so this helps with backward compatibility.
+
+
+def to_python_float(t):
+ if hasattr(t, 'item'):
+ return t.item()
+ else:
+ return t[0]
+
+
+class LossScaler:
+ """
+ Class that manages a static loss scale. This class is intended to interact with
+ :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
+
+ Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
+ :class:`FP16_Optimizer`'s constructor.
+
+ Args:
+ scale (float, optional, default=1.0): The loss scale.
+ """
+
+ def __init__(self, scale=1):
+ self.cur_scale = scale
+
+ # `params` is a list / generator of torch.Variable
+ def has_overflow(self, params):
+ return False
+
+ # `x` is a torch.Tensor
+ def _has_inf_or_nan(x):
+ return False
+
+ def update_scale(self, overflow):
+ pass
+
+ @property
+ def loss_scale(self):
+ return self.cur_scale
+
+ def scale_gradient(self, module, grad_in, grad_out):
+ _overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(amp_C.multi_tensor_scale,
+ _overflow_buf,
+ [grad_in, grad_in],
+ self.loss_scale)
+ return grad_in
+
+ def backward(self, loss, retain_graph=False):
+ scaled_loss = loss * self.loss_scale
+ scaled_loss.backward(retain_graph=retain_graph)
+
+
+class DynamicLossScaler:
+ """
+ Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
+ indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
+ :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
+ operates, because the default options can be changed using the
+ ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
+
+ Loss scaling is designed to combat the problem of underflowing gradients encountered over
+ long runs of fp16 network training. Dynamic loss scaling begins by attempting a very high loss
+ scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
+ encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
+ occurred.
+ :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
+ and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
+ If a certain number of iterations occur without overflowing gradients detected,
+ :class:`DynamicLossScaler` increases the loss scale once more.
+ In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
+ always using the highest loss scale possible without incurring overflow.
+
+ Args:
+ init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`.
+ scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to ``loss_scale/scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to ``loss_scale*scale_factor``.
+ scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
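+
+ Example (a minimal sketch of overriding the defaults through :class:`FP16_Optimizer`; the values shown are illustrative)::
+
+ optimizer = FP16_Optimizer(optimizer,
+ dynamic_loss_scale=True,
+ dynamic_loss_args={'init_scale': 2**16,
+ 'scale_window': 500})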
+ """
+
+ def __init__(self,
+ init_scale=2**32,
+ scale_factor=2.,
+ scale_window=1000,
+ min_scale=1,
+ delayed_shift=1,
+ consecutive_hysteresis=False):
+ self.cur_scale = init_scale
+ self.cur_iter = 0
+ self.last_overflow_iter = -1
+ self.scale_factor = scale_factor
+ self.scale_window = scale_window
+ self.min_scale = min_scale
+ self.delayed_shift = delayed_shift
+ self.cur_hysteresis = delayed_shift
+ self.consecutive_hysteresis = consecutive_hysteresis
+
+ # `params` is a list / generator of torch.Variable
+ def has_overflow_serial(self, params):
+ for p in params:
+ if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+ return True
+
+ return False
+
+ def has_overflow(self, params):
+ overflow = self.has_overflow_serial(params)
+ # Since each model parallel GPU carries only part of the model,
+ # make sure overflow flag is synced across all the model parallel GPUs
+ overflow_gpu = torch.cuda.ByteTensor([overflow])
+ torch.distributed.all_reduce(overflow_gpu,
+ op=torch.distributed.ReduceOp.MAX,
+ group=mpu.get_model_parallel_group())
+ overflow = overflow_gpu[0].item()
+ return bool(overflow)
+
+ # `x` is a torch.Tensor
+
+ def _has_inf_or_nan(x):
+ try:
+ # if x is half, the .float() incurs an additional deep copy, but it's necessary if
+ # Pytorch's .sum() creates a one-element tensor of the same type as x
+ # (which is true for some recent versions of Pytorch).
+ cpu_sum = float(x.float().sum())
+ # More efficient version that can be used if .sum() returns a Python scalar
+ # cpu_sum = float(x.sum())
+ except RuntimeError as instance:
+ # We want to check if inst is actually an overflow exception.
+ # RuntimeError could come from a different error.
+ # If so, we still want the exception to propagate.
+ if "value cannot be converted" not in instance.args[0]:
+ raise
+ return True
+ else:
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+ return True
+ return False
+
+ # `overflow` is boolean indicating whether the gradient overflowed
+ def update_scale(self, overflow):
+
+ if not hasattr(self, 'min_scale'):
+ self.min_scale = 1
+ if not hasattr(self, 'delayed_shift'):
+ self.delayed_shift = 1
+ if not hasattr(self, 'cur_hysteresis'):
+ self.cur_hysteresis = 1
+ if not hasattr(self, 'consecutive_hysteresis'):
+ self.consecutive_hysteresis = True
+ if overflow:
+ # self.cur_scale /= self.scale_factor
+ if self.delayed_shift == 1 or self.cur_hysteresis == 1:
+ self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
+ else:
+ self.cur_hysteresis -= 1
+ self.last_overflow_iter = self.cur_iter
+ else:
+ if self.consecutive_hysteresis:
+ self.cur_hysteresis = self.delayed_shift
+ if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
+ if not self.consecutive_hysteresis:
+ self.cur_hysteresis = self.delayed_shift
+ self.cur_scale *= self.scale_factor
+ self.cur_iter += 1
+
+ @property
+ def loss_scale(self):
+ return self.cur_scale
+
+ def scale_gradient(self, module, grad_in, grad_out):
+ _overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(amp_C.multi_tensor_scale,
+ _overflow_buf,
+ [grad_in, grad_in],
+ self.loss_scale)
+ return grad_in
+
+ def backward(self, loss, retain_graph=False):
+ scaled_loss = loss * self.loss_scale
+ scaled_loss.backward(retain_graph=retain_graph)
+
+
+##############################################################
+# Example usage below here -- assuming it's in a separate file
+##############################################################
+"""
+TO-DO separate out into an example.
+if __name__ == "__main__":
+ import torch
+ from torch.autograd import Variable
+ from dynamic_loss_scaler import DynamicLossScaler
+
+ # N is batch size; D_in is input dimension;
+ # H is hidden dimension; D_out is output dimension.
+ N, D_in, H, D_out = 64, 1000, 100, 10
+
+ # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
+ x = Variable(torch.randn(N, D_in), requires_grad=False)
+ y = Variable(torch.randn(N, D_out), requires_grad=False)
+
+ w1 = Variable(torch.randn(D_in, H), requires_grad=True)
+ w2 = Variable(torch.randn(H, D_out), requires_grad=True)
+ parameters = [w1, w2]
+
+ learning_rate = 1e-6
+ optimizer = torch.optim.SGD(parameters, lr=learning_rate)
+ loss_scaler = DynamicLossScaler()
+
+ for t in range(500):
+ y_pred = x.mm(w1).clamp(min=0).mm(w2)
+ loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
+ print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
+ print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
+ print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
+
+ # Run backprop
+ optimizer.zero_grad()
+ loss.backward()
+
+ # Check for overflow
+ has_overflow = DynamicLossScaler.has_overflow(parameters)
+
+ # If no overflow, unscale grad and update as usual
+ if not has_overflow:
+ for param in parameters:
+ param.grad.data.mul_(1. / loss_scaler.loss_scale)
+ optimizer.step()
+ # Otherwise, don't do anything -- ie, skip iteration
+ else:
+ print('OVERFLOW!')
+
+ # Update loss scale for next iteration
+ loss_scaler.update_scale(has_overflow)
+
+"""
diff --git a/megatron_lm/megatron/fused_kernels/__init__.py b/megatron_lm/megatron/fused_kernels/__init__.py
new file mode 100644
index 0000000..8d5d863
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/__init__.py
@@ -0,0 +1,100 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+import subprocess
+import os
+from torch.utils import cpp_extension
+
+# Setting this variable to a list can generate different compilation commands
+# (with a different ordering of architectures) and lead to recompilation of the
+# fused kernels. Set it to an empty string to avoid recompilation, and assign
+# the arch flags explicitly in extra_cuda_cflags below.
+os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+
+def get_cuda_bare_metal_version(cuda_dir):
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+ universal_newlines=True)
+ output = raw_output.split()
+ release_idx = output.index("release") + 1
+ release = output[release_idx].split(".")
+ bare_metal_major = release[0]
+ bare_metal_minor = release[1][0]
+
+ return raw_output, bare_metal_major, bare_metal_minor
+
+def create_build_dir(buildpath):
+ try:
+ os.mkdir(buildpath)
+ except OSError:
+ if not os.path.isdir(buildpath):
+ print(f"Creation of the build directory {buildpath} failed")
+
+def load_scaled_upper_triang_masked_softmax_fusion_kernel():
+
+ # If CUDA 11 or newer is installed, also build for compute capability 8.0
+ cc_flag = []
+ _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+ if int(bare_metal_major) >= 11:
+ cc_flag.append('-gencode')
+ cc_flag.append('arch=compute_80,code=sm_80')
+
+ srcpath = pathlib.Path(__file__).parent.absolute()
+ buildpath = srcpath / 'build'
+
+ create_build_dir(buildpath)
+
+ scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
+ name='scaled_upper_triang_masked_softmax_cuda',
+ sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+ srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
+ build_directory=buildpath,
+ extra_cflags=['-O3',],
+ extra_cuda_cflags=['-O3',
+ '-gencode', 'arch=compute_70,code=sm_70',
+ '-U__CUDA_NO_HALF_OPERATORS__',
+ '-U__CUDA_NO_HALF_CONVERSIONS__',
+ '--expt-relaxed-constexpr',
+ '--expt-extended-lambda',
+ '--use_fast_math'] + cc_flag)
+
+def load_scaled_masked_softmax_fusion_kernel():
+
+ # If CUDA 11 or newer is installed, also build for compute capability 8.0
+ cc_flag = []
+ _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+ if int(bare_metal_major) >= 11:
+ cc_flag.append('-gencode')
+ cc_flag.append('arch=compute_80,code=sm_80')
+
+ srcpath = pathlib.Path(__file__).parent.absolute()
+ buildpath = srcpath / 'build'
+
+ create_build_dir(buildpath)
+
+ scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
+ name='scaled_masked_softmax_cuda',
+ sources=[srcpath / 'scaled_masked_softmax.cpp',
+ srcpath / 'scaled_masked_softmax_cuda.cu'],
+ build_directory=buildpath,
+ extra_cflags=['-O3',],
+ extra_cuda_cflags=['-O3',
+ '-gencode', 'arch=compute_70,code=sm_70',
+ '-U__CUDA_NO_HALF_OPERATORS__',
+ '-U__CUDA_NO_HALF_CONVERSIONS__',
+ '--expt-relaxed-constexpr',
+ '--expt-extended-lambda',
+ '--use_fast_math'] + cc_flag)
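+
+# Usage note (a sketch, not invoked here): these loaders JIT-compile the
+# extensions on first use, assuming nvcc is available under CUDA_HOME. After a
+# loader has run, the compiled module can be imported by name, e.g.:
+#
+#   from megatron import fused_kernels
+#   fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+#   import scaled_masked_softmax_cuda  # built by the call above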
diff --git a/megatron_lm/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron_lm/megatron/fused_kernels/scaled_masked_softmax.cpp
new file mode 100644
index 0000000..87a55df
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/scaled_masked_softmax.cpp
@@ -0,0 +1,74 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include <vector>
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_masked_softmax {
+
+torch::Tensor fwd_cuda(
+ torch::Tensor const& input,
+ torch::Tensor const& mask,
+ float scale_factor);
+
+torch::Tensor bwd_cuda(
+ torch::Tensor const& output_grads,
+ torch::Tensor const& softmax_results,
+ float scale_factor);
+
+torch::Tensor fwd(
+ torch::Tensor const& input,
+ torch::Tensor const& mask,
+ float scale_factor) {
+ AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
+ AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
+ "Only HALF is supported");
+ AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
+
+ return fwd_cuda(input, mask, scale_factor);
+}
+
+torch::Tensor bwd(
+ torch::Tensor const& output_grads,
+ torch::Tensor const& softmax_results,
+ float scale_factor) {
+
+ AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
+ AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
+
+ AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
+ "Only HALF is supported");
+ AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
+ "Only HALF is supported");
+
+ return bwd_cuda(output_grads, softmax_results, scale_factor);
+}
+
+} // end namespace scaled_masked_softmax
+} // end namespace fused_softmax
+} // end namespace multihead_attn
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward",
+ &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
+ "Self Multihead Attention scaled, time masked softmax -- Forward.");
+ m.def("backward",
+ &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
+ "Self Multihead Attention scaled, time masked softmax -- Backward.");
+}
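+
+// Illustrative only: once built through
+// megatron.fused_kernels.load_scaled_masked_softmax_fusion_kernel(), the
+// extension can be driven from Python roughly as follows (shapes and dtypes
+// here are assumptions: 4D half-precision scores and a 4D mask):
+//
+//   import scaled_masked_softmax_cuda
+//   probs = scaled_masked_softmax_cuda.forward(scores, mask, scale)
+//   grad_scores = scaled_masked_softmax_cuda.backward(grad_probs, probs, scale)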
diff --git a/megatron_lm/megatron/fused_kernels/scaled_masked_softmax.h b/megatron_lm/megatron/fused_kernels/scaled_masked_softmax.h
new file mode 100644
index 0000000..c327a1b
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/scaled_masked_softmax.h
@@ -0,0 +1,452 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <assert.h>
+#include <cuda_fp16.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <cuda_fp16.h>
+#include <c10/macros/Macros.h>
+
+namespace {
+
+int log2_ceil(int value) {
+ int log2_value = 0;
+ while ((1 << log2_value) < value) ++log2_value;
+ return log2_value;
+}
+
+template<typename T>
+struct Add {
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a + b;
+ }
+};
+
+template<typename T>
+struct Max {
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a < b ? b : a;
+ }
+};
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+ return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
+__device__ __forceinline__ void warp_reduce(acc_t* sum) {
+ ReduceOp<acc_t> r;
+ #pragma unroll
+ for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
+ sum[i] = r(sum[i], b);
+ }
+ }
+}
+
+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ * 2) Explicit masking
+ */
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_masked_softmax_warp_forward(
+ output_t *dst,
+ const input_t *src,
+ const uint8_t *mask,
+ const acc_t scale,
+ int batch_size,
+ int stride,
+ int element_count,
+ int pad_batches)
+{
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+ // warp_size of method warp_softmax_forward_kernel.
+ constexpr int next_power_of_two = 1 << log2_elements;
+ constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+ constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+ // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
+ // gridDim/blockIdx = (seq_len, attn_heads, batches)
+ int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
+ int pad_first_batch = 0;
+ if (pad_batches != 1) { // bert style
+ pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
+ } else { // gpt2 style
+ pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+ }
+
+ // batch_size might not be a multiple of WARP_BATCH. Check how
+ // many batches have to be computed within this WARP.
+ int local_batches = batch_size - first_batch;
+ if (local_batches > WARP_BATCH)
+ local_batches = WARP_BATCH;
+
+ // there might be multiple batches per warp. compute the index within the batch
+ int local_idx = threadIdx.x;
+
+ src += first_batch * stride + local_idx;
+ dst += first_batch * stride + local_idx;
+ mask += pad_first_batch * stride + local_idx;
+
+ // load data from global memory
+ acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ int batch_element_count = (i >= local_batches) ? 0 : element_count;
+
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ int itr_idx = i*element_count+it*WARP_SIZE;
+
+ if (element_index < batch_element_count) {
+ if (mask[itr_idx] != 1) {
+ elements[i][it] = (acc_t)src[itr_idx] * scale;
+ } else {
+ elements[i][it] = -10000.0;
+ }
+ } else {
+ elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+ }
+ }
+ }
+
+ // compute max_value
+ acc_t max_value[WARP_BATCH];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ max_value[i] = elements[i][0];
+ #pragma unroll
+ for (int it = 1; it < WARP_ITERATIONS; ++it) {
+ max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+ }
+ }
+ warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+ acc_t sum[WARP_BATCH] { 0.0f };
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+ sum[i] += elements[i][it];
+ }
+ }
+ warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+ // store result
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ if (i >= local_batches)
+ break;
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < element_count) {
+ dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+ } else {
+ break;
+ }
+ }
+ }
+}
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_masked_softmax_warp_backward(
+ output_t *gradInput,
+ input_t *grad,
+ const input_t *output,
+ acc_t scale,
+ int batch_size,
+ int stride,
+ int element_count)
+{
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+ // warp_size of method warp_softmax_backward_kernel.
+ constexpr int next_power_of_two = 1 << log2_elements;
+ constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+ constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+ // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
+ // gridDim/blockIdx = (seq_len, attn_heads, batches)
+ int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+
+ // batch_size might not be a multiple of WARP_BATCH. Check how
+ // many batches have to be computed within this WARP.
+ int local_batches = batch_size - first_batch;
+ if (local_batches > WARP_BATCH)
+ local_batches = WARP_BATCH;
+
+ // there might be multiple batches per warp. compute the index within the batch
+ int local_idx = threadIdx.x;
+
+ // the first element to process by the current thread
+ int thread_offset = first_batch * stride + local_idx;
+ grad += thread_offset;
+ output += thread_offset;
+ gradInput += thread_offset;
+
+ // load data from global memory
+ acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+ acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ int batch_element_count = (i >= local_batches) ? 0 : element_count;
+
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < batch_element_count) {
+ output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
+ } else {
+ output_reg[i][it] = acc_t(0);
+ }
+ }
+
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < batch_element_count) {
+ grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it];
+ } else {
+ grad_reg[i][it] = acc_t(0);
+ }
+ }
+ }
+
+ acc_t sum[WARP_BATCH];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ sum[i] = grad_reg[i][0];
+ #pragma unroll
+ for (int it = 1; it < WARP_ITERATIONS; ++it) {
+ sum[i] += grad_reg[i][it];
+ }
+ }
+ warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+ // store result
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ if (i >= local_batches)
+ break;
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < element_count) {
+ // compute gradients
+ gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+ }
+ }
+ }
+}
+
+} // end of anonymous namespace
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_masked_softmax_forward(
+ output_t *dst,
+ const input_t *src,
+ const uint8_t *mask,
+ const input_t scale,
+ int softmax_elements,
+ int softmax_elements_stride,
+ int batches,
+ int attn_heads,
+ int pad_batches)
+{
+ TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
+ if (softmax_elements == 0) {
+ return;
+ } else {
+ int log2_elements = log2_ceil(softmax_elements);
+ const int next_power_of_two = 1 << log2_elements;
+ int seq_len = softmax_elements;
+ int batch_count = batches * attn_heads * seq_len;
+
+ // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+ // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+ // use 128 threads per block to maximize gpu utilization
+ constexpr int threads_per_block = 128;
+
+ int warps_per_block = (threads_per_block / warp_size);
+ int batches_per_block = warps_per_block * batches_per_warp;
+ TORCH_INTERNAL_ASSERT(seq_len%batches_per_block == 0);
+ dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
+ dim3 threads(warp_size, warps_per_block, 1);
+ // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+ switch (log2_elements) {
+ case 0: // 1
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 1: // 2
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 2: // 4
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 3: // 8
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 4: // 16
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 5: // 32
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 6: // 64
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 7: // 128
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 8: // 256
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 9: // 512
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 10: // 1024
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ case 11: // 2048
+ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_masked_softmax_backward(
+ output_t *grad_input,
+ input_t *grad,
+ const input_t *output,
+ const acc_t scale,
+ int softmax_elements,
+ int softmax_elements_stride,
+ int batches,
+ int attn_heads)
+{
+ TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
+ if (softmax_elements == 0) {
+ return;
+ } else {
+ int log2_elements = log2_ceil(softmax_elements);
+ const int next_power_of_two = 1 << log2_elements;
+ int seq_len = softmax_elements;
+ int batch_count = batches * attn_heads * seq_len;
+
+ // This value must match the WARP_SIZE constexpr value computed inside scaled_masked_softmax_warp_backward.
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+ // This value must match the WARP_BATCH constexpr value computed inside scaled_masked_softmax_warp_backward.
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+ // use 128 threads per block to maximize GPU utilization
+ constexpr int threads_per_block = 128;
+
+ int warps_per_block = (threads_per_block / warp_size);
+ int batches_per_block = warps_per_block * batches_per_warp;
+ int blocks = batch_count/batches_per_block;
+ dim3 threads(warp_size, warps_per_block, 1);
+ // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+ switch (log2_elements) {
+ case 0: // 1
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 1: // 2
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 2: // 4
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 3: // 8
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 4: // 16
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 5: // 32
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 6: // 64
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 7: // 128
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 8: // 256
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 9: // 512
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 10: // 1024
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 11: // 2048
+ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ default:
+ break;
+ }
+ }
+}
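Note (illustration, not part of the patch): the launch-configuration arithmetic above (warp_size, batches_per_warp, warps_per_block and the resulting grid) is compact but easy to misread. The following standalone Python sketch mirrors the host-side math in dispatch_scaled_masked_softmax_forward.

    # Illustrative sketch of the launch math in dispatch_scaled_masked_softmax_forward.
    C10_WARP_SIZE = 32  # assumption: CUDA warp size

    def log2_ceil(value):
        log2_value = 0
        while (1 << log2_value) < value:
            log2_value += 1
        return log2_value

    def masked_softmax_launch_config(softmax_elements, batches, attn_heads):
        next_power_of_two = 1 << log2_ceil(softmax_elements)
        warp_size = min(next_power_of_two, C10_WARP_SIZE)        # must match WARP_SIZE in the kernel
        batches_per_warp = 2 if next_power_of_two <= 128 else 1  # must match WARP_BATCH in the kernel
        threads_per_block = 128
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        seq_len = softmax_elements
        assert seq_len % batches_per_block == 0
        blocks = (seq_len // batches_per_block, attn_heads, batches)
        threads = (warp_size, warps_per_block, 1)
        return blocks, threads

    print(masked_softmax_launch_config(1024, batches=4, attn_heads=16))  # ((256, 16, 4), (32, 4, 1))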
diff --git a/megatron_lm/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron_lm/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
new file mode 100644
index 0000000..63aaccd
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -0,0 +1,102 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include "THC/THC.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_masked_softmax.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_masked_softmax {
+
+torch::Tensor fwd_cuda(
+ torch::Tensor const& input,
+ torch::Tensor const& mask,
+ float scale_factor)
+{
+ // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
+ const int batches = input.size(0);
+ const int pad_batches = mask.size(0);
+ const int attn_heads = input.size(1);
+ const int seq_len = input.size(2);
+ TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+ TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
+ TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
+ TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
+ TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+
+ // Output
+ auto act_options = input.options().requires_grad(false);
+ torch::Tensor softmax_results =
+ torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+
+ // Softmax Intermediate Result Ptr
+ void* input_ptr = static_cast<void*>(input.data_ptr());
+ void* mask_ptr = static_cast<void*>(mask.data_ptr());
+ void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+ dispatch_scaled_masked_softmax_forward<half, half, float>(
+ reinterpret_cast<half*>(softmax_results_ptr),
+ reinterpret_cast<const half*>(input_ptr),
+ reinterpret_cast<const uint8_t*>(mask_ptr),
+ scale_factor,
+ seq_len,
+ seq_len,
+ batches,
+ attn_heads,
+ pad_batches);
+ return softmax_results;
+}
+
+torch::Tensor bwd_cuda(
+ torch::Tensor const& output_grads_,
+ torch::Tensor const& softmax_results_,
+ float scale_factor) {
+
+ auto output_grads = output_grads_.contiguous();
+ auto softmax_results = softmax_results_.contiguous();
+
+ //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
+ const int batches = output_grads.size(0);
+ const int attn_heads = output_grads.size(1);
+ const int seq_len = output_grads.size(2);
+ TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+
+ void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+ //Softmax Grad
+ dispatch_scaled_masked_softmax_backward<half, half, float>(
+ reinterpret_cast<half*>(output_grads_ptr),
+ reinterpret_cast<half*>(output_grads_ptr),
+ reinterpret_cast<half const*>(softmax_results.data_ptr()),
+ scale_factor,
+ seq_len,
+ seq_len,
+ batches,
+ attn_heads);
+
+ //backward pass is completely in-place
+ return output_grads;
+}
+}
+}
+}
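Note (assumption-laden sketch, not part of the patch): the forward wrapper above is numerically equivalent, up to the half-precision round trip, to an eager scaled, masked softmax. The exact fill value applied to masked positions lives in the kernel header; a large negative constant is assumed here.

    import torch

    def scaled_masked_softmax_reference(inputs, mask, scale):
        # inputs: [batches, attn_heads, seq_len, seq_len] (half)
        # mask:   [pad_batches, 1, seq_len, seq_len] (uint8), nonzero entries are masked out
        scaled = inputs.float() * scale
        scaled = scaled.masked_fill(mask.to(torch.bool), -10000.0)  # assumed fill value
        return torch.softmax(scaled, dim=-1).to(inputs.dtype)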
diff --git a/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
new file mode 100644
index 0000000..af5a0c5
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
@@ -0,0 +1,69 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include <vector>
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+ torch::Tensor const& input,
+ float scale_factor);
+
+torch::Tensor bwd_cuda(
+ torch::Tensor const& output_grads,
+ torch::Tensor const& softmax_results,
+ float scale_factor);
+
+torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
+ AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
+ AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
+ "Only HALF is supported");
+
+ return fwd_cuda(input, scale_factor);
+}
+
+torch::Tensor bwd(
+ torch::Tensor const& output_grads,
+ torch::Tensor const& softmax_results,
+ float scale_factor) {
+
+ AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
+ AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
+
+ AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
+ "Only HALF is supported");
+ AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
+ "Only HALF is supported");
+
+ return bwd_cuda(output_grads, softmax_results, scale_factor);
+}
+
+} // end namespace scaled_upper_triang_masked_softmax
+} // end namespace fused_softmax
+} // end namespace multihead_attn
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward",
+ &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
+ "Self Multihead Attention scaled, time masked softmax -- Forward.");
+ m.def("backward",
+ &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
+ "Self Multihead Attention scaled, time masked softmax -- Backward.");
+}
diff --git a/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
new file mode 100644
index 0000000..6f448a3
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
@@ -0,0 +1,439 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <assert.h>
+#include <cuda_fp16.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <cuda_fp16.h>
+#include <c10/macros/Macros.h>
+
+namespace {
+
+int log2_ceil(int value) {
+ int log2_value = 0;
+ while ((1 << log2_value) < value) ++log2_value;
+ return log2_value;
+}
+
+template<typename T>
+struct Add {
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a + b;
+ }
+};
+
+template<typename T>
+struct Max {
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a < b ? b : a;
+ }
+};
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+ return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
+__device__ __forceinline__ void warp_reduce(acc_t* sum) {
+ ReduceOp<acc_t> r;
+ #pragma unroll
+ for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
+ sum[i] = r(sum[i], b);
+ }
+ }
+}
+
+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ * 2) Implicit time (diagonal masking)
+ */
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_upper_triang_masked_softmax_warp_forward(
+ output_t *dst,
+ const input_t *src,
+ const acc_t scale,
+ int batch_size,
+ int stride,
+ int element_count)
+{
+ // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values
+ // computed in dispatch_scaled_upper_triang_masked_softmax_forward.
+ constexpr int next_power_of_two = 1 << log2_elements;
+ constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+ constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+ int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+ int local_seq = blockIdx.x + 1;
+ int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
+
+ // batch_size might not be a multiple of WARP_BATCH. Check how
+ // many batches have to be computed within this WARP.
+ int local_batches = batch_size - first_batch;
+ if (local_batches > WARP_BATCH)
+ local_batches = WARP_BATCH;
+
+ // there might be multiple batches per warp. compute the index within the batch
+ int local_idx = threadIdx.x;
+
+ src += first_batch * stride + local_idx;
+ dst += first_batch * stride + local_idx;
+
+ // load data from global memory
+ acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ int batch_element_count = (i >= local_batches) ? 0 : local_seq;
+
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < batch_element_count) {
+ elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale;
+ } else {
+ elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+ }
+ }
+ }
+
+ // compute max_value
+ acc_t max_value[WARP_BATCH];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ max_value[i] = elements[i][0];
+ #pragma unroll
+ for (int it = 1; it < WARP_ITERATIONS; ++it) {
+ max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+ }
+ }
+ warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+ acc_t sum[WARP_BATCH] { 0.0f };
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ if (it < warp_iteration_limit) {
+ elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+ sum[i] += elements[i][it];
+ }
+ }
+ }
+ warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+ // store result
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ if (i >= local_batches)
+ break;
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < local_seq) {
+ dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+ } else if (element_index < element_count) {
+ dst[i*element_count*stride+it*WARP_SIZE] = 0;
+ } else {
+ break;
+ }
+ }
+ }
+}
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_upper_triang_masked_softmax_warp_backward(
+ output_t *gradInput,
+ input_t *grad,
+ const input_t *output,
+ acc_t scale,
+ int batch_size,
+ int stride,
+ int element_count)
+{
+ // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values
+ // computed in dispatch_scaled_upper_triang_masked_softmax_backward.
+ constexpr int next_power_of_two = 1 << log2_elements;
+ constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+ constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+ int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+ int local_seq = blockIdx.x + 1;
+
+ // batch_size might not be a multiple of WARP_BATCH. Check how
+ // many batches have to be computed within this WARP.
+ int local_batches = batch_size - first_batch;
+ if (local_batches > WARP_BATCH)
+ local_batches = WARP_BATCH;
+
+ // there might be multiple batches per warp. compute the index within the batch
+ int local_idx = threadIdx.x;
+
+ // the first element to process by the current thread
+ int thread_offset = first_batch * stride + local_idx;
+ grad += thread_offset;
+ output += thread_offset;
+ gradInput += thread_offset;
+
+ // load data from global memory
+ acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+ acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ int batch_element_count = (i >= local_batches) ? 0 : local_seq;
+
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < batch_element_count) {
+ output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE];
+ } else {
+ output_reg[i][it] = acc_t(0);
+ }
+ }
+
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < batch_element_count) {
+ grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it];
+ } else {
+ grad_reg[i][it] = acc_t(0);
+ }
+ }
+ }
+
+ acc_t sum[WARP_BATCH];
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ sum[i] = grad_reg[i][0];
+ #pragma unroll
+ for (int it = 1; it < WARP_ITERATIONS; ++it) {
+ sum[i] += grad_reg[i][it];
+ }
+ }
+ warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+ // store result
+ #pragma unroll
+ for (int i = 0; i < WARP_BATCH; ++i) {
+ if (i >= local_batches)
+ break;
+ #pragma unroll
+ for (int it = 0; it < WARP_ITERATIONS; ++it) {
+ int element_index = local_idx + it * WARP_SIZE;
+ if (element_index < element_count) {
+ // compute gradients
+ gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+ }
+ }
+ }
+}
+
+} // end of anonymous namespace
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_upper_triang_masked_softmax_forward(
+ output_t *dst,
+ const input_t *src,
+ const input_t scale,
+ int softmax_elements,
+ int softmax_elements_stride,
+ int attn_batches)
+{
+ TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
+ if (softmax_elements == 0) {
+ return;
+ } else {
+ int log2_elements = log2_ceil(softmax_elements);
+ const int next_power_of_two = 1 << log2_elements;
+ int seq_len = softmax_elements;
+ int batch_count = attn_batches * seq_len;
+
+ // This value must match the WARP_SIZE constexpr value computed inside scaled_upper_triang_masked_softmax_warp_forward.
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+ // This value must match the WARP_BATCH constexpr value computed inside scaled_upper_triang_masked_softmax_warp_forward.
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+ // use 128 threads per block to maximize GPU utilization
+ constexpr int threads_per_block = 128;
+
+ int warps_per_block = (threads_per_block / warp_size);
+ int batches_per_block = warps_per_block * batches_per_warp;
+ TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+ int blocks_per_seq = attn_batches / batches_per_block;
+ dim3 blocks(seq_len, blocks_per_seq, 1);
+ dim3 threads(warp_size, warps_per_block, 1);
+ // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+ switch (log2_elements) {
+ case 0: // 1
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 1: // 2
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 2: // 4
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 3: // 8
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 4: // 16
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 5: // 32
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 6: // 64
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 7: // 128
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 8: // 256
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 9: // 512
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 10: // 1024
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 11: // 2048
+ scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_upper_triang_masked_softmax_backward(
+ output_t *grad_input,
+ input_t *grad,
+ const input_t *output,
+ const acc_t scale,
+ int softmax_elements,
+ int softmax_elements_stride,
+ int attn_batches)
+{
+ TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
+ if (softmax_elements == 0) {
+ return;
+ } else {
+ int log2_elements = log2_ceil(softmax_elements);
+ const int next_power_of_two = 1 << log2_elements;
+ int seq_len = softmax_elements;
+ int batch_count = attn_batches * seq_len;
+
+ // This value must match the WARP_SIZE constexpr value computed inside scaled_upper_triang_masked_softmax_warp_backward.
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+ // This value must match the WARP_BATCH constexpr value computed inside scaled_upper_triang_masked_softmax_warp_backward.
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+ // use 128 threads per block to maximize GPU utilization
+ constexpr int threads_per_block = 128;
+
+ int warps_per_block = (threads_per_block / warp_size);
+ int batches_per_block = warps_per_block * batches_per_warp;
+ TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+ int blocks_per_seq = attn_batches / batches_per_block;
+ dim3 blocks(seq_len, blocks_per_seq, 1);
+ dim3 threads(warp_size, warps_per_block, 1);
+ // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+ switch (log2_elements) {
+ case 0: // 1
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 1: // 2
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 2: // 4
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 3: // 8
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 4: // 16
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 5: // 32
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 6: // 64
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 7: // 128
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 8: // 256
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 9: // 512
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 10: // 1024
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ case 11: // 2048
+ scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+ break;
+ default:
+ break;
+ }
+ }
+}
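Note (illustration, not part of the patch): the upper-triangular variant applies an implicit causal mask. For each [seq_len, seq_len] slice, row r only attends to columns 0..r, and the remaining entries of that row are written as zeros (see the element_index < local_seq branch in the forward kernel). A plain PyTorch reference:

    import torch

    def scaled_upper_triang_masked_softmax_reference(inputs, scale):
        # inputs: [attn_batches, seq_len, seq_len]
        seq_len = inputs.size(1)
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=inputs.device))
        scaled = inputs.float() * scale
        scaled = scaled.masked_fill(~causal, float('-inf'))
        return torch.softmax(scaled, dim=-1).to(inputs.dtype)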
diff --git a/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
new file mode 100644
index 0000000..ffd2757
--- /dev/null
+++ b/megatron_lm/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
@@ -0,0 +1,89 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include "THC/THC.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_upper_triang_masked_softmax.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+ torch::Tensor const& input,
+ float scale_factor)
+{
+ // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+ const int attn_batches = input.size(0);
+ const int seq_len = input.size(1);
+ TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+
+ // Output
+ auto act_options = input.options().requires_grad(false);
+ torch::Tensor softmax_results =
+ torch::empty({attn_batches, seq_len, seq_len}, act_options);
+
+ // Softmax Intermediate Result Ptr
+ void* input_ptr = static_cast<void*>(input.data_ptr());
+ void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+ dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
+ reinterpret_cast<half*>(softmax_results_ptr),
+ reinterpret_cast<const half*>(input_ptr),
+ scale_factor,
+ seq_len,
+ seq_len,
+ attn_batches);
+ return softmax_results;
+}
+
+torch::Tensor bwd_cuda(
+ torch::Tensor const& output_grads_,
+ torch::Tensor const& softmax_results_,
+ float scale_factor) {
+
+ auto output_grads = output_grads_.contiguous();
+ auto softmax_results = softmax_results_.contiguous();
+
+ //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+ const int attn_batches = output_grads.size(0);
+ const int seq_len = output_grads.size(1);
+ TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
+
+ void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+ //Softmax Grad
+ dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
+ reinterpret_cast<half*>(output_grads_ptr),
+ reinterpret_cast<half*>(output_grads_ptr),
+ reinterpret_cast<half const*>(softmax_results.data_ptr()),
+ scale_factor,
+ seq_len,
+ seq_len,
+ attn_batches);
+
+ //backward pass is completely in-place
+ return output_grads;
+}
+}
+}
+}
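Note (illustration, not part of the patch): both backward wrappers compute the standard softmax backward, scaled, entirely in place on the incoming gradient. With y = softmax(scale * x) per row and upstream gradient dy, the kernels produce dx = scale * y * (dy - sum(dy * y)). A minimal eager reference:

    import torch

    def scaled_softmax_backward_reference(grad_output, softmax_results, scale):
        # grad_output, softmax_results: [..., seq_len]; the reduction is over the last dimension
        dot = (grad_output.float() * softmax_results.float()).sum(dim=-1, keepdim=True)
        return (scale * softmax_results.float() * (grad_output.float() - dot)).to(grad_output.dtype)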
diff --git a/megatron_lm/megatron/global_vars.py b/megatron_lm/megatron/global_vars.py
new file mode 100644
index 0000000..8d72a0b
--- /dev/null
+++ b/megatron_lm/megatron/global_vars.py
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron global variables."""
+
+import os
+import sys
+import time
+
+import torch
+
+from megatron.tokenizer import build_tokenizer
+from .arguments import parse_args
+
+_GLOBAL_ARGS = None
+_GLOBAL_TOKENIZER = None
+_GLOBAL_TENSORBOARD_WRITER = None
+_GLOBAL_ADLR_AUTORESUME = None
+_GLOBAL_TIMERS = None
+
+
+def get_args():
+ """Return arguments."""
+ _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
+ return _GLOBAL_ARGS
+
+
+def get_tokenizer():
+ """Return tokenizer."""
+ _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
+ return _GLOBAL_TOKENIZER
+
+
+def get_tensorboard_writer():
+ """Return tensorboard writer. It can be None so no need
+ to check if it is initialized."""
+ return _GLOBAL_TENSORBOARD_WRITER
+
+
+def get_adlr_autoresume():
+ """ADLR autoresume object. It can be None so no need
+ to check if it is initialized."""
+ return _GLOBAL_ADLR_AUTORESUME
+
+
+def get_timers():
+ """Return timers."""
+ _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
+ return _GLOBAL_TIMERS
+
+
+def set_global_variables(extra_args_provider=None, args_defaults={},
+ ignore_unknown_args=False):
+ """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
+ args = _parse_args(extra_args_provider=extra_args_provider,
+ defaults=args_defaults,
+ ignore_unknown_args=ignore_unknown_args)
+ _ = _build_tokenizer(args)
+ _set_tensorboard_writer(args)
+ _set_adlr_autoresume(args)
+ _set_timers()
+
+
+def _parse_args(extra_args_provider=None, defaults={},
+ ignore_unknown_args=False):
+ """Parse entire arguments."""
+ global _GLOBAL_ARGS
+ _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
+ _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
+ defaults=defaults,
+ ignore_unknown_args=ignore_unknown_args)
+ return _GLOBAL_ARGS
+
+
+def _build_tokenizer(args):
+ """Initialize tokenizer."""
+ global _GLOBAL_TOKENIZER
+ _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
+ _GLOBAL_TOKENIZER = build_tokenizer(args)
+ return _GLOBAL_TOKENIZER
+
+
+def rebuild_tokenizer(args):
+ global _GLOBAL_TOKENIZER
+ _GLOBAL_TOKENIZER = None
+ return _build_tokenizer(args)
+
+
+def _set_tensorboard_writer(args):
+ """Set tensorboard writer."""
+ global _GLOBAL_TENSORBOARD_WRITER
+ _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
+ 'tensorboard writer')
+
+ if hasattr(args, 'tensorboard_dir') and \
+ args.tensorboard_dir and args.rank == 0:
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ print('> setting tensorboard ...')
+ _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
+ log_dir=args.tensorboard_dir)
+ except ModuleNotFoundError:
+ print('WARNING: TensorBoard writing requested but is not '
+ 'available (are you using PyTorch 1.1.0 or later?), '
+ 'no TensorBoard logs will be written.', flush=True)
+
+
+def _set_adlr_autoresume(args):
+ """Initialize ADLR autoresume."""
+ global _GLOBAL_ADLR_AUTORESUME
+ _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
+
+ if args.adlr_autoresume:
+ if args.rank == 0:
+ print('enabling autoresume ...', flush=True)
+ sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
+ try:
+ from userlib.auto_resume import AutoResume
+ except BaseException:
+ print('ADLR autoresume is not available, exiting ...')
+ sys.exit()
+
+ _GLOBAL_ADLR_AUTORESUME = AutoResume
+
+
+def _set_timers():
+ """Initialize timers."""
+ global _GLOBAL_TIMERS
+ _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
+ _GLOBAL_TIMERS = Timers()
+
+
+def _ensure_var_is_initialized(var, name):
+ """Make sure the input variable is not None."""
+ assert var is not None, '{} is not initialized.'.format(name)
+
+
+def _ensure_var_is_not_initialized(var, name):
+ """Make sure the input variable is not None."""
+ assert var is None, '{} is already initialized.'.format(name)
+
+
+class _Timer:
+ """Timer."""
+
+ def __init__(self, name):
+ self.name_ = name
+ self.elapsed_ = 0.0
+ self.started_ = False
+ self.start_time = time.time()
+
+ def start(self):
+ """Start the timer."""
+ assert not self.started_, 'timer has already been started'
+ torch.cuda.synchronize()
+ self.start_time = time.time()
+ self.started_ = True
+
+ def stop(self):
+ """Stop the timer."""
+ assert self.started_, 'timer is not started'
+ torch.cuda.synchronize()
+ self.elapsed_ += (time.time() - self.start_time)
+ self.started_ = False
+
+ def reset(self):
+ """Reset timer."""
+ self.elapsed_ = 0.0
+ self.started_ = False
+
+ def elapsed(self, reset=True):
+ """Calculate the elapsed time."""
+ started_ = self.started_
+ # If timing is in progress, stop it first.
+ if self.started_:
+ self.stop()
+ # Get the elapsed time.
+ elapsed_ = self.elapsed_
+ # Reset the elapsed time
+ if reset:
+ self.reset()
+ # If timing was in progress, set it back.
+ if started_:
+ self.start()
+ return elapsed_
+
+
+class Timers:
+ """Group of timers."""
+
+ def __init__(self):
+ self.timers = {}
+
+ def __call__(self, name):
+ if name not in self.timers:
+ self.timers[name] = _Timer(name)
+ return self.timers[name]
+
+ def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+ """Write timers to a tensorboard writer"""
+ # currently when using add_scalars,
+ # torch.utils.add_scalars makes each timer its own run, which
+ # pollutes the runs list, so we just add each as a scalar
+ assert normalizer > 0.0
+ for name in names:
+ value = self.timers[name].elapsed(reset=reset) / normalizer
+ writer.add_scalar(name + '_time', value, iteration)
+
+ def log(self, names, normalizer=1.0, reset=True):
+ """Log a group of timers."""
+ assert normalizer > 0.0
+ string = 'time (ms)'
+ for name in names:
+ elapsed_time = self.timers[name].elapsed(
+ reset=reset) * 1000.0 / normalizer
+ string += ' | {}: {:.2f}'.format(name, elapsed_time)
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == 0:
+ print(string, flush=True)
+ else:
+ print(string, flush=True)
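A minimal usage sketch for the Timers container defined above (timer names are illustrative; _Timer synchronizes the device, so a CUDA-capable environment is assumed):

    from megatron.global_vars import Timers

    timers = Timers()
    timers('forward').start()
    # ... run the forward pass ...
    timers('forward').stop()
    timers.log(['forward'])  # prints: time (ms) | forward: ...
    # timers.write(['forward'], writer, iteration)  # optionally log to a SummaryWriter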
diff --git a/megatron_lm/megatron/indexer.py b/megatron_lm/megatron/indexer.py
new file mode 100644
index 0000000..fd65c3f
--- /dev/null
+++ b/megatron_lm/megatron/indexer.py
@@ -0,0 +1,91 @@
+import torch
+import torch.distributed as dist
+
+from megatron import get_args
+from megatron import mpu
+from megatron.checkpointing import load_ict_checkpoint
+from megatron.data.ict_dataset import get_ict_dataset
+from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, BlockData
+from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.realm_model import general_ict_model_provider
+from megatron.training import get_model
+
+
+class IndexBuilder(object):
+ """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
+ def __init__(self):
+ args = get_args()
+ self.model = None
+ self.dataloader = None
+ self.block_data = None
+
+ # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+ assert not (args.load and args.ict_load)
+ self.using_realm_chkpt = args.ict_load is None
+
+ self.log_interval = args.indexer_log_interval
+ self.batch_size = args.indexer_batch_size
+
+ self.load_attributes()
+ self.is_main_builder = mpu.get_data_parallel_rank() == 0
+ self.num_total_builders = mpu.get_data_parallel_world_size()
+ self.iteration = self.total_processed = 0
+
+ def load_attributes(self):
+ """Load the necessary attributes: model, dataloader and empty BlockData"""
+ model = get_model(lambda: general_ict_model_provider(only_block_model=True))
+ self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+ self.model.eval()
+ self.dataset = get_ict_dataset()
+ self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
+ self.block_data = BlockData(load_from_path=False)
+
+ def track_and_report_progress(self, batch_size):
+ """Utility function for tracking progress"""
+ self.iteration += 1
+ self.total_processed += batch_size * self.num_total_builders
+ if self.is_main_builder and self.iteration % self.log_interval == 0:
+ print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)
+
+ def build_and_save_index(self):
+ """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
+
+ The copy of BlockData is saved as a shard, which when run in a distributed setting will be
+ consolidated by the rank 0 process and saved as a final pickled BlockData.
+ """
+
+ while True:
+ try:
+ # batch also has query_tokens and query_pad_data
+ _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
+ except (StopIteration, IndexError):
+ break
+
+ unwrapped_model = self.model
+ while not hasattr(unwrapped_model, 'embed_block'):
+ unwrapped_model = unwrapped_model.module
+
+ # detach, separate fields and add to BlockData
+ block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
+ detached_data = detach(block_sample_data)
+
+ # block_sample_data is a 2D array [batch x 4]
+ # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
+ block_indices = detached_data[:, 3]
+ block_metas = detached_data[:, :3]
+
+ self.block_data.add_block_data(block_indices, block_logits, block_metas)
+ self.track_and_report_progress(batch_size=block_tokens.shape[0])
+
+ # This process saves its shard and then synchronizes with the other processes
+ self.block_data.save_shard()
+ torch.distributed.barrier()
+ del self.model
+
+ # rank 0 process builds the final copy
+ if self.is_main_builder:
+ self.block_data.merge_shards_and_save()
+ # make sure that every single piece of data was embedded
+ assert len(self.block_data.embed_data) == len(self.dataset)
+ self.block_data.clear()
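An illustrative driver for the IndexBuilder above; it assumes Megatron has already been initialized so that get_args(), mpu, and the ICT/REALM checkpoint arguments it reads are available:

    from megatron.indexer import IndexBuilder

    builder = IndexBuilder()          # loads the model, dataloader and an empty BlockData
    builder.build_and_save_index()    # one pass over the data; rank 0 merges the shards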
diff --git a/megatron_lm/megatron/initialize.py b/megatron_lm/megatron/initialize.py
new file mode 100644
index 0000000..f79c751
--- /dev/null
+++ b/megatron_lm/megatron/initialize.py
@@ -0,0 +1,208 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron initialization."""
+
+import random
+import os
+
+import numpy as np
+import torch
+
+from megatron import get_adlr_autoresume
+from megatron import get_args
+from megatron import get_tensorboard_writer
+from megatron import mpu
+from megatron.global_vars import set_global_variables
+from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size
+
+import deepspeed
+
+
+def initialize_megatron(extra_args_provider=None, args_defaults={},
+ ignore_unknown_args=False, allow_no_cuda=False):
+ """Set global variables, initialize distributed, and
+ set autoresume and random seeds.
+ `allow_no_cuda` should not be set unless using Megatron for CPU-only
+ data processing. In general this arg should not be set unless you know
+ what you are doing.
+ Returns a function to finalize distributed env initialization
+ (optionally, only when args.lazy_mpu_init == True)
+
+"""
+ if not allow_no_cuda:
+ # Make sure cuda is available.
+ assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+
+ # Parse args, build tokenizer, and set adlr-autoresume,
+ # tensorboard-writer, and timers.
+ set_global_variables(extra_args_provider=extra_args_provider,
+ args_defaults=args_defaults,
+ ignore_unknown_args=ignore_unknown_args)
+
+ # torch.distributed initialization
+ def finish_mpu_init():
+ args = get_args()
+ # Pytorch distributed.
+ _initialize_distributed()
+
+ # Random seeds for reproducibility.
+ if args.rank == 0:
+ print('> setting random seeds to {} ...'.format(args.seed))
+ _set_random_seed(args.seed)
+
+ args = get_args()
+ if args.lazy_mpu_init:
+ args.use_cpu_initialization=True
+ # delayed initialization of DDP-related stuff
+ # We only set basic DDP globals
+ set_model_parallel_world_size(args.model_parallel_size)
+ # and return function for external DDP manager to call when it has DDP initialized
+ set_model_parallel_rank(args.rank)
+ return finish_mpu_init
+ else:
+ # Megatron's MPU is the master. Complete initialization right away.
+ finish_mpu_init()
+
+ # Initialize memory buffers.
+ _initialize_mem_buffs()
+
+ # Autoresume.
+ _init_autoresume()
+
+ # Write arguments to tensorboard.
+ _write_args_to_tensorboard()
+ # No continuation function
+ return None
+
+
+def setup_deepspeed_random_and_activation_checkpointing(args):
+ '''Optional DeepSpeed Activation Checkpointing features.
+ Gives access to partition activations, contiguous memory optimizations
+ and cpu checkpointing.
+
+ Activation checkpointing requires keeping track of the random states
+ and setting the random seed for each MP process. Megatron uses
+ mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
+ for keeping track of the random states and setting the random seeds.
+ Since they are used in places outside of activation checkpointing,
+ we overwrite them to maintain consistency.
+
+ This must be called before all the calls to mpu.model_parallel_cuda_manual_seed
+ '''
+ num_layers = args.num_layers // args.checkpoint_num_layers
+ num_layers = num_layers if args.num_layers % args.checkpoint_num_layers == 0 else num_layers + 1
+ if args.split_transformers:
+ num_layers *= 2
+
+ deepspeed.checkpointing.configure(
+ mpu,
+ partition_activations=args.partition_activations,
+ contiguous_checkpointing=args.contigious_checkpointing,
+ num_checkpoints=num_layers,
+ checkpoint_in_cpu=args.checkpoint_in_cpu,
+ synchronize=args.synchronize_each_layer,
+ profile=args.profile_backward)
+
+ mpu.checkpoint = deepspeed.checkpointing.checkpoint
+ mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
+ mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
+
+
+def _initialize_distributed():
+ """Initialize torch.distributed and mpu."""
+ args = get_args()
+
+ device_count = torch.cuda.device_count()
+ if torch.distributed.is_initialized():
+
+ if args.rank == 0:
+ print('torch distributed is already initialized, '
+ 'skipping initialization ...', flush=True)
+ args.rank = torch.distributed.get_rank()
+ args.world_size = torch.distributed.get_world_size()
+
+ else:
+
+ if args.rank == 0:
+ print('> initializing torch distributed ...', flush=True)
+ # Manually set the device ids.
+ if device_count > 0:
+ device = args.rank % device_count
+ if args.local_rank is not None:
+ assert args.local_rank == device, \
+ 'expected local-rank to be the same as rank % device-count.'
+ else:
+ args.local_rank = device
+ torch.cuda.set_device(device)
+ # Call the init process
+ init_method = 'tcp://'
+ master_ip = os.getenv('MASTER_ADDR', 'localhost')
+ master_port = os.getenv('MASTER_PORT', '6000')
+ init_method += master_ip + ':' + master_port
+ torch.distributed.init_process_group(
+ backend=args.distributed_backend,
+ world_size=args.world_size, rank=args.rank,
+ init_method=init_method)
+
+ # Set the model-parallel / data-parallel communicators.
+ if device_count > 0:
+ if mpu.model_parallel_is_initialized():
+ print('model parallel is already initialized')
+ else:
+ mpu.initialize_model_parallel(args.model_parallel_size)
+
+ # Optional DeepSpeed Activation Checkpointing Features
+ #
+ if args.deepspeed and args.deepspeed_activation_checkpointing:
+ setup_deepspeed_random_and_activation_checkpointing(args)
+
+def _init_autoresume():
+ """Set autoresume start time."""
+ autoresume = get_adlr_autoresume()
+ if autoresume:
+ torch.distributed.barrier()
+ autoresume.init()
+ torch.distributed.barrier()
+
+
+def _set_random_seed(seed):
+ """Set random seed for reproducability."""
+ if seed is not None and seed > 0:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.device_count() > 0:
+ mpu.model_parallel_cuda_manual_seed(seed)
+ else:
+ raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
+
+
+def _write_args_to_tensorboard():
+ """Write arguments to tensorboard."""
+ args = get_args()
+ writer = get_tensorboard_writer()
+ if writer:
+ for arg in vars(args):
+ writer.add_text(arg, str(getattr(args, arg)))
+
+
+def _initialize_mem_buffs():
+ """Initialize manually allocated static memory."""
+ args = get_args()
+
+ # Initialize memory for checkpointed activations.
+ if args.distribute_checkpointed_activations:
+ mpu.init_checkpointed_activations_memory_buffer()
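A sketch of a typical single-process entry point using initialize_megatron; the environment variables mirror the defaults read by _initialize_distributed, and the tokenizer default shown is illustrative:

    import os
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '6000')

    from megatron.initialize import initialize_megatron

    # Parses arguments, builds the tokenizer, sets up torch.distributed and the MPU.
    initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})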
diff --git a/megatron_lm/megatron/learning_rates.py b/megatron_lm/megatron/learning_rates.py
new file mode 100644
index 0000000..19be32b
--- /dev/null
+++ b/megatron_lm/megatron/learning_rates.py
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Learning rate decay functions."""
+
+import math
+
+from megatron import print_rank_0, get_args
+
+
+class AnnealingLR(object):
+ """Anneals the learning rate."""
+
+ def __init__(self, optimizer, start_lr,
+ warmup_iter, total_iters,
+ decay_style, last_iter, min_lr=0.0,
+ use_checkpoint_lr_scheduler=True,
+ override_lr_scheduler=False):
+ args = get_args()
+ # Class values.
+ self.optimizer = optimizer
+ self.start_lr = start_lr
+ self.min_lr = min_lr
+ self.warmup_iter = warmup_iter
+ self.num_iters = last_iter
+ self.end_iter = total_iters
+ assert self.end_iter > 0
+ self.lr_decay_tokens = args.lr_decay_tokens
+ self.num_tokens = 0
+ self.warmup_tokens = 0
+ self.decay_style = decay_style
+ self.override_lr_scheduler = override_lr_scheduler
+ self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
+ if self.override_lr_scheduler:
+ assert not self.use_checkpoint_lr_scheduler, 'both override and '\
+ 'use-checkpoint are set.'
+ # Set the learning rate
+ self.step(self.num_iters, self.num_tokens)
+
+ print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
+
+ def get_lr(self):
+ """Learning rate decay functions from:
+ https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
+
+ # Warmup.
+ if self.warmup_iter > 0:
+ if self.num_iters == self.warmup_iter and self.lr_decay_tokens is not None:
+ self.warmup_tokens = self.num_tokens
+ if self.num_iters <= self.warmup_iter:
+ return float(self.start_lr) * self.num_iters / self.warmup_iter
+
+ if self.lr_decay_tokens is None:
+ # For any iterations larger than `self.end_iter`, use `self.min_lr`.
+ if self.num_iters > self.end_iter:
+ return self.min_lr
+ # If we are done with the warmup period, use the decay style.
+ current_iter = self.num_iters - self.warmup_iter
+ decay_iter = self.end_iter - self.warmup_iter
+ decay_ratio = float(current_iter) / float(decay_iter)
+ else:
+ if self.num_tokens > self.lr_decay_tokens:
+ return self.min_lr
+ current_tokens = self.num_tokens - self.warmup_tokens
+ decay_tokens = self.lr_decay_tokens - self.warmup_tokens
+ decay_ratio = float(current_tokens) / float(decay_tokens)
+ assert decay_ratio >= 0.0
+ assert decay_ratio <= 1.0
+
+ if self.decay_style == 'linear':
+ lr = self.start_lr * (1.0 - decay_ratio)
+ elif self.decay_style == 'cosine':
+ lr = self.start_lr / 2.0 * (math.cos(
+ math.pi * decay_ratio) + 1)
+ elif self.decay_style == 'exponential':
+ # exp(-0.693) = 1/2
+ lr = self.start_lr * math.exp(-0.693 * decay_ratio)
+ else:
+ lr = self.start_lr
+ return max(lr, self.min_lr)
+
+ def step(self, step_num=None, token_num=None):
+ """Set lr for all parameters groups."""
+ args = get_args()
+ if step_num is None:
+ step_num = self.num_iters + 1
+ if token_num is None:
+ token_num = args.tokens
+ self.num_iters = step_num
+ self.num_tokens = token_num
+ new_lr = self.get_lr()
+ for group in self.optimizer.param_groups:
+ group['lr'] = new_lr
+
+ def state_dict(self):
+ state_dict = {
+ 'start_lr': self.start_lr,
+ 'warmup_iter': self.warmup_iter,
+ 'num_iters': self.num_iters,
+ 'warmup_tokens': self.warmup_tokens,
+ 'num_tokens': self.num_tokens,
+ 'decay_style': self.decay_style,
+ 'end_iter': self.end_iter,
+ 'min_lr': self.min_lr
+ }
+ return state_dict
+
+ def _check_and_set(self, cls_value, sd_value, name):
+ """Auxiliary function for checking the values in the checkpoint and
+ setting them."""
+ if self.override_lr_scheduler:
+ print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
+ return cls_value
+
+ if not self.use_checkpoint_lr_scheduler:
+ assert cls_value == sd_value, 'AnnealingLR: class input value ' \
+ 'and checkpoint values for {} do not match'.format(name)
+ print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
+ name))
+ return sd_value
+
+ def load_state_dict(self, sd):
+
+ self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
+ 'learning rate')
+ self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
+ 'minimum learning rate')
+ self.warmup_iter = self._check_and_set(self.warmup_iter,
+ sd['warmup_iter'],
+ 'warmup iterations')
+ self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
+ 'total number of iterations')
+ self.decay_style = self._check_and_set(self.decay_style,
+ sd['decay_style'],
+ 'decay style')
+
+ self.num_iters = sd['num_iters']
+ if 'warmup_tokens' in sd:
+ self.warmup_tokens = sd['warmup_tokens']
+ if 'num_tokens' in sd:
+ self.num_tokens = sd['num_tokens']
+ self.step(self.num_iters, self.num_tokens)
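The iteration-based branch of AnnealingLR.get_lr reduces to a small closed-form schedule. The standalone sketch below reproduces the 'cosine' decay style with illustrative hyperparameters (not values taken from this patch):

    import math

    def cosine_lr(num_iters, start_lr=1.5e-4, min_lr=1e-5, warmup_iter=2000, end_iter=300000):
        # Linear warmup followed by cosine decay, clamped at min_lr.
        if warmup_iter > 0 and num_iters <= warmup_iter:
            return start_lr * num_iters / warmup_iter
        if num_iters > end_iter:
            return min_lr
        decay_ratio = (num_iters - warmup_iter) / (end_iter - warmup_iter)
        lr = start_lr / 2.0 * (math.cos(math.pi * decay_ratio) + 1)
        return max(lr, min_lr)

    for it in (0, 1000, 2000, 151000, 300000):
        print(it, cosine_lr(it))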
diff --git a/megatron_lm/megatron/memory.py b/megatron_lm/megatron/memory.py
new file mode 100644
index 0000000..be5a117
--- /dev/null
+++ b/megatron_lm/megatron/memory.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+# A dictionary of all the memory buffers allocated.
+_MEM_BUFFS = dict()
+
+
+def allocate_mem_buff(name, numel, dtype, track_usage):
+ """Allocate a memory buffer."""
+ assert name not in _MEM_BUFFS, \
+ 'memory buffer {} already allocated.'.format(name)
+ _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
+ return _MEM_BUFFS[name]
+
+
+def get_mem_buff(name):
+ """Get the memory buffer."""
+ return _MEM_BUFFS[name]
+
+
+class MemoryBuffer:
+ """Contiguous memory buffer.
+ Allocate a contiguous memory of type `dtype` and size `numel`. It is
+ used to reduce memory fragmentation.
+
+ Usage: After the allocation, the `_start` index is set to the first
+ index of the memory. A memory chunk starting from `_start` index
+ can be `allocated` for an input tensor, with the elements of the
+ tensor being copied. The buffer can be reused by resetting the
+ `_start` index.
+
+ """
+ def __init__(self, name, numel, dtype, track_usage):
+ if torch.distributed.get_rank() == 0:
+ element_size = torch.tensor([], dtype=dtype).element_size()
+ print('> building the {} memory buffer with {} num elements '
+ 'and {} dtype ({:.1f} MB)...'.format(
+ name, numel, dtype, numel*element_size/1024/1024),
+ flush=True)
+ self.name = name
+ self.numel = numel
+ self.dtype = dtype
+ self.data = torch.empty(self.numel,
+ dtype=self.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False)
+
+ # Index tracking the start of the free memory.
+ self._start = 0
+
+ # Values used for tracking usage.
+ self.track_usage = track_usage
+ if self.track_usage:
+ self.in_use_value = 0.0
+ self.total_value = 0.0
+
+
+ def reset(self):
+ """Reset the buffer start index to the beginning of the buffer."""
+ self._start = 0
+
+
+ def is_in_use(self):
+ """Whether the current buffer hold on to any memory."""
+ return self._start > 0
+
+
+ def numel_in_use(self):
+ """Return number of elements in use."""
+ return self._start
+
+
+ def add(self, tensor):
+ """Allocate a chunk of memory from the buffer to tensor and copy
+ the values."""
+ assert tensor.dtype == self.dtype, \
+ 'Input tensor type {} different from buffer type {}'.format(
+ tensor.dtype, self.dtype)
+ # Number of elements of the input tensor.
+ tensor_numel = torch.numel(tensor)
+ new_start = self._start + tensor_numel
+ assert new_start <= self.numel, \
+ 'Not enough memory left in the buffer ({} > {})'.format(
+ tensor_numel, self.numel - self._start)
+ # New tensor is a view into the memory.
+ new_tensor = self.data[self._start:new_start]
+ self._start = new_start
+ new_tensor = new_tensor.view(tensor.shape)
+ new_tensor.copy_(tensor)
+ # Return a pointer to the new tensor.
+ return new_tensor
+
+
+ def get_data(self):
+ """Return the data currently in use."""
+ if self.track_usage:
+ self.in_use_value += float(self._start)
+ self.total_value += float(self.numel)
+ return self.data[:self._start]
+
+
+ def print_average_usage(self):
+ """Print memory usage average over time. We would like this value
+ to be as high as possible."""
+ assert self.track_usage, 'You need to enable track usage.'
+ if torch.distributed.get_rank() == 0:
+ print(' > usage of {} memory buffer: {:.2f} %'.format(
+ self.name, self.in_use_value * 100.0 / self.total_value),
+ flush=True)
+
+
+
+class RingMemBuffer:
+ """A ring of memory buffers."""
+
+ def __init__(self, name, num_buffers, numel, dtype, track_usage):
+ self.num_buffers = num_buffers
+ self.buffers = [
+ allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage)
+ for i in range(num_buffers)]
+ self._index = -1
+
+
+ def get_next_buffer(self):
+ self._index += 1
+ self._index = self._index % self.num_buffers
+ buff = self.buffers[self._index]
+ assert not buff.is_in_use(), 'buffer is already in use.'
+ return buff
diff --git a/megatron_lm/megatron/model/__init__.py b/megatron_lm/megatron/model/__init__.py
new file mode 100755
index 0000000..984a104
--- /dev/null
+++ b/megatron_lm/megatron/model/__init__.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .distributed import *
+from .bert_model import BertModel
+from .realm_model import ICTBertModel
+from .gpt2_model import GPT2Model
+from .utils import get_params_for_weight_decay_optimization
+from .language_model import get_language_model
diff --git a/megatron_lm/megatron/model/bert_model.py b/megatron_lm/megatron/model/bert_model.py
new file mode 100644
index 0000000..6b02da7
--- /dev/null
+++ b/megatron_lm/megatron/model/bert_model.py
@@ -0,0 +1,196 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT model."""
+
+import torch
+
+from megatron import get_args
+from megatron import mpu
+from megatron.model.language_model import parallel_lm_logits
+from megatron.model.language_model import get_language_model
+from megatron.model.transformer import LayerNorm
+from megatron.model.utils import openai_gelu, erf_gelu
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.utils import scaled_init_method_normal
+from megatron.module import MegatronModule
+
+def bert_attention_mask_func(attention_scores, attention_mask):
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ return attention_scores
+
+def bert_extended_attention_mask(attention_mask):
+ # We create a 3D attention mask from a 2D tensor mask.
+ # [b, 1, s]
+ attention_mask_b1s = attention_mask.unsqueeze(1)
+ # [b, s, 1]
+ attention_mask_bs1 = attention_mask.unsqueeze(2)
+ # [b, s, s]
+ attention_mask_bss = attention_mask_b1s * attention_mask_bs1
+ # [b, 1, s, s]
+ extended_attention_mask = attention_mask_bss.unsqueeze(1)
+
+ # Convert attention mask to binary:
+ extended_attention_mask = (extended_attention_mask < 0.5)
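+    # Note: True entries mark positions that must not be attended to;
+    # bert_attention_mask_func fills them with -10000.0 before the softmax.
+    # For example (illustrative only), a padding mask [1, 1, 0] per sample
+    # (1 = real token, 0 = padding) yields an [s, s] block whose last row and
+    # column are True.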
+
+ return extended_attention_mask
+
+def bert_position_ids(token_ids):
+ # Create position ids
+ seq_length = token_ids.size(1)
+ position_ids = torch.arange(seq_length, dtype=torch.long,
+ device=token_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
+
+ return position_ids
+
+
+class BertLMHead(MegatronModule):
+ """Masked LM head for Bert
+
+ Arguments:
+        mpu_vocab_size: vocabulary size on the local model-parallel partition.
+        hidden_size: hidden size
+        init_method: init method for weight initialization
+        layernorm_epsilon: epsilon added inside the layer norm for numerical stability
+        parallel_output: whether the output logits stay distributed across
+                         model-parallel partitions or are gathered.
+ """
+
+ def __init__(self, mpu_vocab_size, hidden_size, init_method,
+ layernorm_epsilon, parallel_output):
+
+ super(BertLMHead, self).__init__()
+
+ args = get_args()
+
+ self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
+ self.bias.model_parallel = True
+ self.bias.partition_dim = 0
+ self.bias.partition_stride = 1
+ self.parallel_output = parallel_output
+
+ self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+ self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+ self.gelu = torch.nn.functional.gelu
+
+ def forward(self, hidden_states, word_embeddings_weight):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ hidden_states = self.layernorm(hidden_states)
+ output = parallel_lm_logits(hidden_states,
+ word_embeddings_weight,
+ self.parallel_output,
+ bias=self.bias)
+ return output
+
+
+class BertModel(MegatronModule):
+ """Bert Language model."""
+
+ def __init__(self, num_tokentypes=2, add_binary_head=True,
+ parallel_output=True):
+ super(BertModel, self).__init__()
+ args = get_args()
+
+ self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+ self.add_binary_head = add_binary_head
+ self.parallel_output = parallel_output
+ init_method = init_method_normal(args.init_method_std)
+ scaled_init_method = scaled_init_method_normal(args.init_method_std,
+ args.num_layers)
+
+ self.language_model, self._language_model_key = get_language_model(
+ attention_mask_func=bert_attention_mask_func,
+ num_tokentypes=num_tokentypes,
+ add_pooler=self.add_binary_head,
+ init_method=init_method,
+ scaled_init_method=scaled_init_method)
+
+ self.lm_head = BertLMHead(
+ self.language_model.embedding.word_embeddings.weight.size(0),
+ args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
+ self._lm_head_key = 'lm_head'
+ if self.add_binary_head:
+ self.binary_head = get_linear_layer(args.hidden_size, 2,
+ init_method)
+ self._binary_head_key = 'binary_head'
+
+ def forward(self, input_ids, attention_mask,
+ tokentype_ids=None, lm_labels=None):
+
+ extended_attention_mask = bert_extended_attention_mask(attention_mask)
+ position_ids = bert_position_ids(input_ids)
+
+ if self.add_binary_head:
+ lm_output, pooled_output = self.language_model(
+ input_ids,
+ position_ids,
+ extended_attention_mask,
+ tokentype_ids=tokentype_ids)
+ else:
+ lm_output = self.language_model(
+ input_ids,
+ position_ids,
+ extended_attention_mask,
+ tokentype_ids=tokentype_ids)
+
+ # Output.
+ lm_logits = self.lm_head(
+ lm_output, self.language_model.embedding.word_embeddings.weight)
+
+ binary_logits = None
+ if self.add_binary_head:
+ binary_logits = self.binary_head(pooled_output)
+
+ if lm_labels is None:
+ return lm_logits, binary_logits
+ else:
+ if self.fp16_lm_cross_entropy:
+ assert lm_logits.dtype == torch.half
+ lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+ else:
+ lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+ lm_labels)
+ return lm_loss, binary_logits
+
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load when model is combined with other heads,
+ add an extra key."""
+
+ state_dict_ = {}
+ state_dict_[self._language_model_key] \
+ = self.language_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._lm_head_key] \
+ = self.lm_head.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ if self.add_binary_head:
+ state_dict_[self._binary_head_key] \
+ = self.binary_head.state_dict(destination, prefix, keep_vars)
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ self.language_model.load_state_dict(
+ state_dict[self._language_model_key], strict=strict)
+ self.lm_head.load_state_dict(
+ state_dict[self._lm_head_key], strict=strict)
+ if self.add_binary_head:
+ self.binary_head.load_state_dict(
+ state_dict[self._binary_head_key], strict=strict)
diff --git a/megatron_lm/megatron/model/classification.py b/megatron_lm/megatron/model/classification.py
new file mode 100644
index 0000000..5c69d95
--- /dev/null
+++ b/megatron_lm/megatron/model/classification.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classification model."""
+
+import torch
+
+from megatron import get_args, print_rank_0
+from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.language_model import get_language_model
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.utils import scaled_init_method_normal
+from megatron.module import MegatronModule
+
+
+class Classification(MegatronModule):
+
+ def __init__(self, num_classes, num_tokentypes=2):
+ super(Classification, self).__init__()
+ args = get_args()
+
+ self.num_classes = num_classes
+ init_method = init_method_normal(args.init_method_std)
+
+ self.language_model, self._language_model_key = get_language_model(
+ attention_mask_func=bert_attention_mask_func,
+ num_tokentypes=num_tokentypes,
+ add_pooler=True,
+ init_method=init_method,
+ scaled_init_method=scaled_init_method_normal(args.init_method_std,
+ args.num_layers))
+
+        # Classification head.
+ self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
+ self.classification_head = get_linear_layer(args.hidden_size,
+ self.num_classes,
+ init_method)
+ self._classification_head_key = 'classification_head'
+
+ def forward(self, input_ids, attention_mask, tokentype_ids):
+
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
+ position_ids = bert_position_ids(input_ids)
+
+ _, pooled_output = self.language_model(input_ids,
+ position_ids,
+ extended_attention_mask,
+ tokentype_ids=tokentype_ids)
+
+ # Output.
+ classification_output = self.classification_dropout(pooled_output)
+ classification_logits = self.classification_head(classification_output)
+
+        # Reshape to [batch size, num_classes].
+ classification_logits = classification_logits.view(-1, self.num_classes)
+
+ return classification_logits
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load when model is combined with other heads,
+ add an extra key."""
+
+ state_dict_ = {}
+ state_dict_[self._language_model_key] \
+ = self.language_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._classification_head_key] \
+ = self.classification_head.state_dict(
+ destination, prefix, keep_vars)
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ self.language_model.load_state_dict(
+ state_dict[self._language_model_key], strict=strict)
+ if self._classification_head_key in state_dict:
+ self.classification_head.load_state_dict(
+ state_dict[self._classification_head_key], strict=strict)
+ else:
+ print_rank_0('***WARNING*** could not find {} in the checkpoint, '
+ 'initializing to random'.format(
+ self._classification_head_key))
diff --git a/megatron_lm/megatron/model/distributed.py b/megatron_lm/megatron/model/distributed.py
new file mode 100755
index 0000000..d49cb96
--- /dev/null
+++ b/megatron_lm/megatron/model/distributed.py
@@ -0,0 +1,112 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+import torch.distributed as dist
+from torch.nn.modules import Module
+from torch.autograd import Variable
+
+from megatron import mpu
+from megatron.module import MegatronModule
+
+
+class DistributedDataParallel(MegatronModule):
+
+ def __init__(self, module):
+ super(DistributedDataParallel, self).__init__()
+ self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+ self.module = module
+ self.data_parallel_group = mpu.get_data_parallel_group()
+
+ def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
+            if self.needs_reduction:
+ self.needs_reduction = False
+ buckets = {}
+ for name, param in self.module.named_parameters():
+ if param.requires_grad and param.grad is not None:
+ tp = (param.data.type())
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(param)
+ if self.warn_on_half:
+ if torch.cuda.HalfTensor in buckets:
+ print("WARNING: gloo dist backend for half parameters may be extremely slow." +
+ " It is recommended to use the NCCL backend in this case.")
+ self.warn_on_half = False
+ for tp in buckets:
+ bucket = buckets[tp]
+ grads = [param.grad.data for param in bucket]
+ coalesced = _flatten_dense_tensors(grads)
+ if fp32_allreduce:
+ coalesced = coalesced.float()
+ if not no_scale and not reduce_after:
+ coalesced /= dist.get_world_size(group=self.data_parallel_group)
+ dist.all_reduce(coalesced, group=self.data_parallel_group)
+ torch.cuda.synchronize()
+ if not no_scale and reduce_after:
+ coalesced /= dist.get_world_size(group=self.data_parallel_group)
+ for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+ buf.copy_(synced)
+ self.hook_handles = []
+ self.hooks = []
+ for param in list(self.module.parameters()):
+ def allreduce_hook(*unused):
+ Variable._execution_engine.queue_callback(allreduce_params)
+ # handle = param.register_hook(allreduce_hook)
+ # self.hooks.append(allreduce_hook)
+ # self.hook_handles.append(handle)
+ self.allreduce_params = allreduce_params
+
+ def forward(self, *inputs, **kwargs):
+ self.needs_reduction = True
+ return self.module(*inputs, **kwargs)
+
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
+ #[h.remove() for h in self.hook_handles]
+ sd = self.module.state_dict(destination, prefix, keep_vars)
+ # for handle, hook in zip(self.hook_handles, self.hooks):
+ # d = handle.hooks_dict_ref()
+ # d[handle.id] = hook
+
+ return sd
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ return self.module.state_dict_for_save_checkpoint(destination, prefix,
+ keep_vars)
+
+ def load_state_dict(self, state_dict, strict=True):
+ self.module.load_state_dict(state_dict, strict=strict)
+
+ '''
+ def _sync_buffers(self):
+ buffers = list(self.module._all_buffers())
+ if len(buffers) > 0:
+ # cross-node buffer sync
+ flat_buffers = _flatten_dense_tensors(buffers)
+ dist.broadcast(flat_buffers, 0)
+ for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
+ buf.copy_(synced)
+ def train(self, mode=True):
+ # Clear NCCL communicator and CUDA event cache of the default group ID,
+ # These cache will be recreated at the later call. This is currently a
+ # work-around for a potential NCCL deadlock.
+ if dist._backend == dist.dist_backend.NCCL:
+ dist._clear_group_cache()
+ super(DistributedDataParallel, self).train(mode)
+ self.module.train(mode)
+ '''
diff --git a/megatron_lm/megatron/model/fused_bias_gelu.py b/megatron_lm/megatron/model/fused_bias_gelu.py
new file mode 100644
index 0000000..8e17a30
--- /dev/null
+++ b/megatron_lm/megatron/model/fused_bias_gelu.py
@@ -0,0 +1,60 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+###### BIAS GELU FUSION/ NO AUTOGRAD ################
+# 1/sqrt(2*pi)-> 0.3989423
+# 1/sqrt(2) -> 0.70710678
+# sqrt(2/pi) -> 0.79788456
+# this function is the tanh approximation of gelu
+# actual gelu is:
+# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+
+@torch.jit.script
+def bias_gelu(bias, y):
+ x = bias + y
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+# gradient of tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def bias_gelu_back(g, bias, y):
+ x = bias + y
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+ return ff*g
+
+class GeLUFunction(torch.autograd.Function):
+ @staticmethod
+ # bias is an optional argument
+ def forward(ctx, input, bias):
+ ctx.save_for_backward(input, bias)
+ return bias_gelu(bias, input)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, bias = ctx.saved_tensors
+ tmp = bias_gelu_back(grad_output, bias, input)
+ return tmp, tmp
+
+bias_gelu_impl = GeLUFunction.apply
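+
+# A minimal usage sketch (shapes, dtype and device are illustrative
+# assumptions); both tensors are expected to live on the GPU in practice:
+#   y = torch.randn(8, 1024, dtype=torch.half, device='cuda')
+#   b = torch.zeros(1024, dtype=torch.half, device='cuda', requires_grad=True)
+#   out = bias_gelu_impl(y, b)  # fused bias-add followed by tanh-approximated GeLU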
diff --git a/megatron_lm/megatron/model/fused_softmax.py b/megatron_lm/megatron/model/fused_softmax.py
new file mode 100644
index 0000000..d5cf992
--- /dev/null
+++ b/megatron_lm/megatron/model/fused_softmax.py
@@ -0,0 +1,127 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs the following three operations in sequence:
+ 1. Scale the tensor.
+ 2. Apply upper triangular mask (typically used in gpt models).
+ 3. Perform softmax.
+ """
+ @staticmethod
+ def forward(ctx, inputs, scale):
+ import scaled_upper_triang_masked_softmax_cuda
+ scale_t = torch.tensor([scale])
+
+ softmax_results = \
+ scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
+ ctx.save_for_backward(softmax_results, scale_t)
+ return softmax_results
+
+ @staticmethod
+ def backward(ctx, output_grads):
+ import scaled_upper_triang_masked_softmax_cuda
+ softmax_results, scale_t = ctx.saved_tensors
+
+ input_grads = \
+ scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
+ softmax_results,
+ scale_t[0])
+ return input_grads, None
+
+class ScaledMaskedSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs the following three operations in sequence:
+ 1. Scale the tensor.
+ 2. Apply the mask.
+ 3. Perform softmax.
+ """
+ @staticmethod
+ def forward(ctx, inputs, mask, scale):
+ import scaled_masked_softmax_cuda
+ scale_t = torch.tensor([scale])
+
+ softmax_results = \
+ scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+ ctx.save_for_backward(softmax_results, scale_t)
+ return softmax_results
+
+ @staticmethod
+ def backward(ctx, output_grads):
+ import scaled_masked_softmax_cuda
+ softmax_results, scale_t = ctx.saved_tensors
+
+ input_grads = \
+ scaled_masked_softmax_cuda.backward(output_grads,
+ softmax_results,
+ scale_t[0])
+ return input_grads, None, None
+
+class FusedScaleMaskSoftmax(torch.nn.Module):
+ """
+ fused operation: scaling + mask + softmax
+ Arguments:
+        input_in_fp16: flag to indicate whether the input is in fp16 format.
+        upper_triang_mask_fusion: if true, use the fused kernel with an upper
+                                  triangular mask (used in GPT-family networks).
+        general_mask_fusion: if true, use the fused kernel with a general mask.
+        mask_func: mask function to be applied.
+        softmax_in_fp32: if true, softmax is performed at fp32 precision.
+        scale: scaling factor used in input tensor scaling.
+
+ """
+ def __init__(self, input_in_fp16, upper_triang_mask_fusion,
+ general_mask_fusion, mask_func, softmax_in_fp32, scale):
+ super(FusedScaleMaskSoftmax, self).__init__()
+ self.input_in_fp16 = input_in_fp16
+ self.upper_triang_mask_fusion = upper_triang_mask_fusion
+ self.general_mask_fusion = general_mask_fusion
+ self.mask_func = mask_func
+ self.softmax_in_fp32 = softmax_in_fp32
+ self.scale = scale
+
+ assert self.scale is None or softmax_in_fp32, \
+ 'softmax should be in fp32 when scaled'
+
+ def forward(self, input, mask):
+ # [b, np, s, s]
+ data_size = input.size()
+ assert input.dim() == 4
+
+ # invoke custom kernel
+ if self.input_in_fp16 and data_size[-1] <= 2048 and \
+ (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
+ input.size()[2] == input.size()[3]:
+ scale = self.scale if self.scale is not None else 1.0
+ if self.upper_triang_mask_fusion:
+ input = input.view(-1, data_size[2], data_size[3])
+ probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+ probs = probs.view(*data_size)
+ else:
+ probs = ScaledMaskedSoftmax.apply(input, mask, scale)
+ else:
+ if self.input_in_fp16 and self.softmax_in_fp32:
+ input = input.float()
+
+ if self.scale is not None:
+ input = input * self.scale
+ mask_output = self.mask_func(input, mask)
+ probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+ if self.input_in_fp16 and self.softmax_in_fp32:
+ probs = probs.half()
+
+ return probs
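+
+# A minimal usage sketch (argument values are illustrative assumptions).
+# The fused CUDA kernels are used only when the input is fp16, square along
+# the last two dims and no longer than 2048; otherwise forward() falls back
+# to a plain masked softmax:
+#   softmax = FusedScaleMaskSoftmax(
+#       input_in_fp16=True, upper_triang_mask_fusion=True,
+#       general_mask_fusion=False,
+#       mask_func=lambda scores, mask: scores.masked_fill_(mask, -10000.0),
+#       softmax_in_fp32=True, scale=None)
+#   probs = softmax(attention_scores, attention_mask)  # scores: [b, np, s, s]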
diff --git a/megatron_lm/megatron/model/gpt2_model.py b/megatron_lm/megatron/model/gpt2_model.py
new file mode 100644
index 0000000..579d897
--- /dev/null
+++ b/megatron_lm/megatron/model/gpt2_model.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPT-2 model."""
+
+import torch
+
+from megatron import get_args
+from megatron import mpu
+from megatron.module import MegatronModule
+
+from .language_model import parallel_lm_logits
+from .language_model import get_language_model
+from .utils import init_method_normal
+from .utils import scaled_init_method_normal
+
+import deepspeed
+
+def gpt2_attention_mask_func(attention_scores, ltor_mask):
+ attention_scores.masked_fill_(ltor_mask, -10000.0)
+ return attention_scores
+
+
+class GPT2Model(MegatronModule):
+ """GPT-2 Language model."""
+
+ def __init__(self, num_tokentypes=0, parallel_output=True):
+ super(GPT2Model, self).__init__()
+ args = get_args()
+
+ self.parallel_output = parallel_output
+ self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+
+ self.language_model, self._language_model_key = get_language_model(
+ attention_mask_func=gpt2_attention_mask_func,
+ num_tokentypes=num_tokentypes,
+ add_pooler=False,
+ init_method=init_method_normal(args.init_method_std),
+ scaled_init_method=scaled_init_method_normal(args.init_method_std,
+ args.num_layers))
+
+
+ def forward(self, input_ids, position_ids, attention_mask, labels=None,
+ tokentype_ids=None, layer_past=None, get_key_value=False,
+ forward_method_parallel_output=None, curriculum_seqlen=None):
+ args = get_args()
+ if curriculum_seqlen is not None:
+ args.curriculum_seqlen = curriculum_seqlen
+ if curriculum_seqlen < input_ids.size()[1]:
+ # seqlen-based curriculum learning
+ # input_ids, position_ids, labels have size [batch size, seqlen]
+ input_ids = input_ids[:, :curriculum_seqlen].contiguous()
+ position_ids = position_ids[:, :curriculum_seqlen].contiguous()
+ labels = labels[:, :curriculum_seqlen].contiguous()
+
+ # attention_mask has size [1, 1, seqlen, seqlen]
+ attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
+ else:
+ if hasattr(args, "curriculum_learning") and args.curriculum_learning: # fix
+                # curriculum_seqlen was not provided: reset it to the full sequence length.
+ args.curriculum_seqlen = args.seq_length
+
+ # Language model.
+ lm_output = self.language_model(input_ids,
+ position_ids,
+ attention_mask,
+ tokentype_ids=tokentype_ids,
+ layer_past=layer_past,
+ get_key_value=get_key_value)
+
+ if get_key_value:
+ lm_output, presents = lm_output
+
+ # Output.
+ parallel_output = self.parallel_output
+ if forward_method_parallel_output is not None:
+ parallel_output = forward_method_parallel_output
+
+ output = parallel_lm_logits(
+ lm_output[0],
+ self.language_model.embedding.word_embeddings.weight,
+ parallel_output,
+ bias=lm_output[1])
+
+ if get_key_value:
+ output = [output, presents]
+
+ if labels is None:
+ return output
+ else:
+ if self.fp16_lm_cross_entropy:
+ assert output.dtype == torch.half
+ loss = mpu.vocab_parallel_cross_entropy(output, labels)
+ else:
+ loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+ return loss
+
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+
+ state_dict_ = {}
+ state_dict_[self._language_model_key] \
+ = self.language_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ if self._language_model_key in state_dict:
+ state_dict = state_dict[self._language_model_key]
+ self.language_model.load_state_dict(state_dict, strict=strict)
diff --git a/megatron_lm/megatron/model/language_model.py b/megatron_lm/megatron/model/language_model.py
new file mode 100644
index 0000000..5f89834
--- /dev/null
+++ b/megatron_lm/megatron/model/language_model.py
@@ -0,0 +1,503 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer based language model."""
+
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import mpu
+from megatron.module import MegatronModule
+from megatron.model.transformer import ParallelTransformer, LayerNorm
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal, scaled_init_method_normal
+
+import deepspeed
+
+def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
+ bias=None):
+ """LM logits using word embedding weights."""
+ # Parallel logits.
+ input_parallel = mpu.copy_to_model_parallel_region(input_)
+ # Matrix multiply.
+ if bias is None:
+ logits_parallel = F.linear(input_parallel, word_embeddings_weight)
+ else:
+ logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
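+    # Note: with parallel_output=True the logits stay partitioned along the
+    # vocab dimension (one shard per model-parallel rank), which is the layout
+    # mpu.vocab_parallel_cross_entropy consumes during training; gathering the
+    # full vocabulary is typically only needed at inference time.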
+ # Gather if needed.
+ if parallel_output:
+ return logits_parallel
+
+ return mpu.gather_from_model_parallel_region(logits_parallel)
+
+
+def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+ init_method=None, scaled_init_method=None):
+ """Build language model and return along with the key to save."""
+ args = get_args()
+
+ if init_method is None:
+ init_method = init_method_normal(args.init_method_std)
+
+ if scaled_init_method is None:
+ scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+
+ # Language model.
+ language_model = TransformerLanguageModel(
+ attention_mask_func=attention_mask_func,
+ init_method=init_method,
+ output_layer_init_method=scaled_init_method,
+ num_tokentypes=num_tokentypes,
+ add_pooler=add_pooler)
+ # key used for checkpoints.
+ language_model_key = 'language_model'
+
+ return language_model, language_model_key
+
+
+class Pooler(MegatronModule):
+ """Pooler layer.
+
+ Pool hidden states of a specific token (for example start of the
+ sequence) and add a linear transformation followed by a tanh.
+
+ Arguments:
+ hidden_size: hidden size
+ init_method: weight initialization method for the linear layer.
+ bias is set to zero.
+ """
+
+ def __init__(self, hidden_size, init_method):
+ super(Pooler, self).__init__()
+ self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+
+ def forward(self, hidden_states, sequence_index=0):
+ # hidden_states: [b, s, h]
+ # sequence_index: index of the token to pool.
+ pooled = hidden_states[:, sequence_index, :]
+ pooled = self.dense(pooled)
+ pooled = torch.tanh(pooled)
+ return pooled
+
+
+class Embedding(MegatronModule):
+ """Language model embeddings.
+
+ Arguments:
+        embedding_size: embedding dimension
+        vocab_size: vocabulary size
+        max_sequence_length: maximum size of sequence. This
+                             is used for positional embedding
+        pos_encoding_type: type of positional encoding
+                           (e.g. 'trainable_absolute')
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ scattered_embeddings: perform elementwise-operations on
+ partitioned embedding activations.
+ introduces minor dropout differences
+                              between MP configurations.
+ """
+
+ def __init__(self,
+ embedding_size,
+ vocab_size,
+ max_sequence_length,
+ pos_encoding_type,
+ embedding_dropout_prob,
+ init_method,
+ num_tokentypes=0,
+ scattered_embeddings=False):
+ super(Embedding, self).__init__()
+
+ self.embedding_size = embedding_size
+ self.init_method = init_method
+ self.num_tokentypes = num_tokentypes
+ self.scattered_embeddings = scattered_embeddings
+ self.pos_encoding_type = pos_encoding_type
+
+ # Word embeddings (parallel).
+ self.word_embeddings = mpu.VocabParallelEmbedding(
+ vocab_size,
+ self.embedding_size,
+ init_method=self.init_method
+ )
+ self._word_embeddings_key = 'word_embeddings'
+
+ # Position embedding (serial).
+ if pos_encoding_type == 'trainable_absolute':
+ self.position_embeddings = torch.nn.Embedding(
+ max_sequence_length, self.embedding_size)
+ self._position_embeddings_key = 'position_embeddings'
+
+ with deepspeed.zero.GatheredParameters(self.position_embeddings.weight,
+ modifier_rank=0):
+ # Initialize the position embeddings.
+ self.init_method(self.position_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrain model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
+ self.embedding_size)
+ with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight,
+ modifier_rank=0):
+ # Initialize the token-type embeddings.
+ self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings is already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes),
+ flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
+ self.embedding_size)
+ with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight,
+ modifier_rank=0):
+ # Initialize the token-type embeddings.
+ self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, input_ids, position_ids, tokentype_ids=None):
+ if self.scattered_embeddings:
+ scatter = mpu.scatter_to_model_parallel_region
+ gather = mpu.gather_from_model_parallel_region
+ else:
+ # do nothing
+ scatter = lambda x: x
+ gather = lambda x: x
+
+ # Embeddings.
+ words_embeddings = scatter(self.word_embeddings(input_ids))
+
+ if self.pos_encoding_type == 'trainable_absolute':
+ position_embeddings = scatter(self.position_embeddings(position_ids))
+ embeddings = words_embeddings + position_embeddings
+ else:
+ embeddings = words_embeddings
+
+ if tokentype_ids is not None:
+ assert self.tokentype_embeddings is not None
+ embeddings = embeddings + scatter(self.tokentype_embeddings(tokentype_ids))
+ else:
+ assert self.tokentype_embeddings is None
+
+ # Dropout.
+ embeddings = gather(self.embedding_dropout(embeddings))
+
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._word_embeddings_key] \
+ = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+ if self.pos_encoding_type == 'trainable_absolute':
+ state_dict_[self._position_embeddings_key] \
+ = self.position_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Word embedding.
+ if self._word_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._word_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'word_embeddings' in key:
+ state_dict_[key.split('word_embeddings.')[1]] \
+ = state_dict[key]
+ self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Position embedding (only present for trainable absolute encodings).
+        if self.pos_encoding_type == 'trainable_absolute':
+            if self._position_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._position_embeddings_key]
+            else:
+                # for backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'position_embeddings' in key:
+                        state_dict_[key.split('position_embeddings.')[1]] \
+                            = state_dict[key]
+            self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_,
+ strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it', flush=True)
+
+
+class Projector(MegatronModule):
+ def __init__(self):
+ super().__init__()
+
+ args = get_args()
+ self.embedding_size = args.embedding_size
+ self.hidden_size = args.hidden_size
+ self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
+
+ if not self.apply_residual_connection_post_layernorm:
+ self.input_layernorm = LayerNorm(
+ args.embedding_size,
+ eps=args.layernorm_epsilon
+ )
+
+ if args.embedding_size != args.hidden_size:
+ self.register_buffer(
+ "projector",
+ torch.eye(args.embedding_size, args.hidden_size).to(args.params_dtype),
+ persistent=False,
+ )
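+            # Note: `projector` is a fixed rectangular identity and is not
+            # trained. When hidden_size > embedding_size the extra hidden
+            # dimensions start out as zeros; when it is smaller, the trailing
+            # embedding dimensions are dropped by the matmul in forward().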
+
+ def forward(self, data):
+ if self.apply_residual_connection_post_layernorm:
+ hidden_states = data
+ else:
+ hidden_states = self.input_layernorm(data)
+
+ if self.embedding_size != self.hidden_size:
+ hidden_states = hidden_states @ self.projector
+
+ return hidden_states
+
+
+class OutputLayer(MegatronModule):
+ def __init__(self, init_method):
+ super().__init__()
+ args = get_args()
+
+ self.input_layer_norm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon
+ )
+
+ self.dense = mpu.RowParallelLinear(
+ args.hidden_size,
+ args.embedding_size,
+ input_is_parallel=False,
+ init_method=init_method,
+ skip_bias_add=False,
+ )
+
+ self.activation_func = F.gelu
+
+ self.output_layer_norm = LayerNorm(
+ args.embedding_size,
+ eps=args.layernorm_epsilon
+ )
+
+ self.output_bias = torch.nn.Parameter(
+ torch.zeros(
+ mpu.divide(
+ args.padded_vocab_size,
+ mpu.get_model_parallel_world_size(),
+ )
+ )
+ )
+
+
+ def forward(self, input_data):
+ if isinstance(input_data, torch.Tensor):
+ hidden_states = input_data
+ else:
+ assert len(input_data) == 2, f"Unknown presents format, output of transformer of len {len(input_data)} is {input_data}"
+ hidden_states = input_data[0]
+ presents = input_data[1]
+
+ output = self.input_layer_norm(hidden_states)
+ output, _ = self.dense(output)
+ output = self.activation_func(output)
+ output = self.output_layer_norm(output)
+ output = [output, self.output_bias]
+
+ if isinstance(input_data, torch.Tensor):
+ return output
+ else:
+ return [output, presents]
+
+
+class TransformerLanguageModel(MegatronModule):
+ """Transformer language model.
+
+ Arguments:
+ transformer_hparams: transformer hyperparameters
+    attention_mask_func: a function that takes `unmasked-attention-scores`
+ with size [b, np, s, s] and an `attention-mask` and will apply
+ the masking. The function should return a masked score of the
+ same size [b, np, s, s].
+ masked-attention-scores = attention_mask_func(
+        unmasked-attention-scores, attention-mask)
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ attention_mask_func,
+ init_method,
+ output_layer_init_method,
+ num_tokentypes=0,
+ add_pooler=False):
+ super(TransformerLanguageModel, self).__init__()
+ args = get_args()
+
+ self.embedding_size = args.embedding_size
+ self.hidden_size = args.hidden_size
+ self.num_tokentypes = num_tokentypes
+ self.init_method = init_method
+ self.add_pooler = add_pooler
+
+ # Embeddings
+ self.embedding = Embedding(self.embedding_size,
+ args.padded_vocab_size,
+ args.max_position_embeddings,
+ args.pos_encoding_type,
+ args.hidden_dropout,
+ self.init_method,
+ self.num_tokentypes,
+ scattered_embeddings=args.scattered_embeddings)
+ self._embedding_key = 'embedding'
+
+ self.projector = Projector()
+
+ # Transformer
+ self.transformer = ParallelTransformer(
+ attention_mask_func, self.init_method,
+ output_layer_init_method)
+ self._transformer_key = 'transformer'
+
+ self.output_layer = OutputLayer(init_method=self.init_method)
+
+ # Pooler
+ if self.add_pooler:
+ self.pooler = Pooler(self.hidden_size, self.init_method)
+ self._pooler_key = 'pooler'
+
+ def forward(self, input_ids, position_ids, attention_mask,
+ tokentype_ids=None, layer_past=None, get_key_value=False,
+ pooling_sequence_index=0):
+
+ # Embeddings.
+ embedding_output = self.embedding(input_ids, position_ids,
+ tokentype_ids=tokentype_ids)
+
+ # Projector!
+ projector_output = self.projector(embedding_output)
+
+ # Transformer.
+ transformer_output = self.transformer(projector_output,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value)
+
+ # OutputLayer!
+ output = self.output_layer(transformer_output)
+
+ if self.add_pooler:
+ pooled_output = self.pooler(output,
+ pooling_sequence_index)
+ return output, pooled_output
+
+ return output
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._embedding_key] \
+ = self.embedding.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._transformer_key] \
+ = self.transformer.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ if self.add_pooler:
+ state_dict_[self._pooler_key] \
+ = self.pooler.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Embedding.
+ if self._embedding_key in state_dict:
+ state_dict_ = state_dict[self._embedding_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if '_embeddings' in key:
+ state_dict_[key] = state_dict[key]
+ self.embedding.load_state_dict(state_dict_, strict=strict)
+
+ # Transformer.
+ if self._transformer_key in state_dict:
+ state_dict_ = state_dict[self._transformer_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'transformer.' in key:
+ state_dict_[key.split('transformer.')[1]] = state_dict[key]
+ self.transformer.load_state_dict(state_dict_, strict=strict)
+
+ # Pooler.
+ if self.add_pooler:
+ assert 'pooler' in state_dict, \
+ 'could not find data for pooler in the checkpoint'
+ self.pooler.load_state_dict(state_dict[self._pooler_key],
+ strict=strict)
diff --git a/megatron_lm/megatron/model/multiple_choice.py b/megatron_lm/megatron/model/multiple_choice.py
new file mode 100644
index 0000000..97de025
--- /dev/null
+++ b/megatron_lm/megatron/model/multiple_choice.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multiple choice model."""
+
+import torch
+
+from megatron import get_args, print_rank_0
+from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.language_model import get_language_model
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.utils import scaled_init_method_normal
+from megatron.module import MegatronModule
+
+
+class MultipleChoice(MegatronModule):
+
+ def __init__(self, num_tokentypes=2):
+ super(MultipleChoice, self).__init__()
+ args = get_args()
+
+ init_method = init_method_normal(args.init_method_std)
+
+ self.language_model, self._language_model_key = get_language_model(
+ attention_mask_func=bert_attention_mask_func,
+ num_tokentypes=num_tokentypes,
+ add_pooler=True,
+ init_method=init_method,
+ scaled_init_method=scaled_init_method_normal(args.init_method_std,
+ args.num_layers))
+
+ # Multi-choice head.
+ self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
+ self.multichoice_head = get_linear_layer(args.hidden_size, 1,
+ init_method)
+ self._multichoice_head_key = 'multichoice_head'
+
+ def forward(self, input_ids, attention_mask, tokentype_ids):
+
+ # [batch, choices, sequence] --> [batch * choices, sequence] -->
+ # transformer --> [batch, choices] --> softmax
+
+ # Ensure the shape is [batch-size, choices, sequence]
+ assert len(input_ids.shape) == 3
+ assert len(attention_mask.shape) == 3
+ assert len(tokentype_ids.shape) == 3
+
+ # Reshape and treat choice dimension the same as batch.
+ num_choices = input_ids.shape[1]
+ input_ids = input_ids.view(-1, input_ids.size(-1))
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1))
+ tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
+
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
+ position_ids = bert_position_ids(input_ids)
+
+ _, pooled_output = self.language_model(input_ids,
+ position_ids,
+ extended_attention_mask,
+ tokentype_ids=tokentype_ids)
+
+ # Output.
+ multichoice_output = self.multichoice_dropout(pooled_output)
+ multichoice_logits = self.multichoice_head(multichoice_output)
+
+ # Reshape back to separate choices.
+ multichoice_logits = multichoice_logits.view(-1, num_choices)
+
+ return multichoice_logits
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load when model is combined with other heads,
+ add an extra key."""
+
+ state_dict_ = {}
+ state_dict_[self._language_model_key] \
+ = self.language_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._multichoice_head_key] \
+ = self.multichoice_head.state_dict(
+ destination, prefix, keep_vars)
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ self.language_model.load_state_dict(
+ state_dict[self._language_model_key], strict=strict)
+ if self._multichoice_head_key in state_dict:
+ self.multichoice_head.load_state_dict(
+ state_dict[self._multichoice_head_key], strict=strict)
+ else:
+ print_rank_0('***WARNING*** could not find {} in the checkpoint, '
+ 'initializing to random'.format(
+ self._multichoice_head_key))
diff --git a/megatron_lm/megatron/model/realm_model.py b/megatron_lm/megatron/model/realm_model.py
new file mode 100644
index 0000000..74bc5cf
--- /dev/null
+++ b/megatron_lm/megatron/model/realm_model.py
@@ -0,0 +1,204 @@
+import os
+import torch
+
+from megatron import get_args, print_rank_0
+from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
+from megatron.model import BertModel
+from megatron.module import MegatronModule
+from megatron import mpu
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.language_model import get_language_model
+from megatron.model.utils import scaled_init_method_normal
+from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+
+
+def general_ict_model_provider(only_query_model=False, only_block_model=False):
+ """Build the model."""
+ args = get_args()
+ assert args.ict_head_size is not None, \
+ "Need to specify --ict-head-size to provide an ICTBertModel"
+
+ assert args.model_parallel_size == 1, \
+ "Model parallel size > 1 not supported for ICT"
+
+ print_rank_0('building ICTBertModel...')
+
+ # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
+ model = ICTBertModel(
+ ict_head_size=args.ict_head_size,
+ num_tokentypes=2,
+ parallel_output=True,
+ only_query_model=only_query_model,
+ only_block_model=only_block_model)
+
+ return model
+
+
+class ICTBertModel(MegatronModule):
+ """Bert-based module for Inverse Cloze task."""
+ def __init__(self,
+ ict_head_size,
+ num_tokentypes=1,
+ parallel_output=True,
+ only_query_model=False,
+ only_block_model=False):
+ super(ICTBertModel, self).__init__()
+ bert_kwargs = dict(
+ ict_head_size=ict_head_size,
+ num_tokentypes=num_tokentypes,
+ parallel_output=parallel_output
+ )
+ assert not (only_block_model and only_query_model)
+ self.use_block_model = not only_query_model
+ self.use_query_model = not only_block_model
+
+ if self.use_query_model:
+ # this model embeds (pseudo-)queries - Embed_input in the paper
+ self.query_model = IREncoderBertModel(**bert_kwargs)
+ self._query_key = 'question_model'
+
+ if self.use_block_model:
+ # this model embeds evidence blocks - Embed_doc in the paper
+ self.block_model = IREncoderBertModel(**bert_kwargs)
+ self._block_key = 'context_model'
+
+ def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
+ """Run a forward pass for each of the models and return the respective embeddings."""
+ query_logits = self.embed_query(query_tokens, query_attention_mask)
+ block_logits = self.embed_block(block_tokens, block_attention_mask)
+ return query_logits, block_logits
+
+ def embed_query(self, query_tokens, query_attention_mask):
+ """Embed a batch of tokens using the query model"""
+ if self.use_query_model:
+ query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+ query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
+ return query_ict_logits
+ else:
+ raise ValueError("Cannot embed query without query model.")
+
+ def embed_block(self, block_tokens, block_attention_mask):
+ """Embed a batch of tokens using the block model"""
+ if self.use_block_model:
+ block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
+ block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
+ return block_ict_logits
+ else:
+ raise ValueError("Cannot embed block without block model.")
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+ """Save dict with state dicts of each of the models."""
+ state_dict_ = {}
+ if self.use_query_model:
+ state_dict_[self._query_key] \
+ = self.query_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+
+ if self.use_block_model:
+ state_dict_[self._block_key] \
+ = self.block_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Load the state dicts of each of the models"""
+ if self.use_query_model:
+ print("Loading ICT query model", flush=True)
+ self.query_model.load_state_dict(
+ state_dict[self._query_key], strict=strict)
+
+ if self.use_block_model:
+ print("Loading ICT block model", flush=True)
+ self.block_model.load_state_dict(
+ state_dict[self._block_key], strict=strict)
+
+ def init_state_dict_from_bert(self):
+ """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
+ args = get_args()
+ tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
+ if not os.path.isfile(tracker_filename):
+ raise FileNotFoundError("Could not find BERT load for ICT")
+ with open(tracker_filename, 'r') as f:
+ iteration = int(f.read().strip())
+ assert iteration > 0
+
+ checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
+ if mpu.get_data_parallel_rank() == 0:
+ print('global rank {} is loading checkpoint {}'.format(
+ torch.distributed.get_rank(), checkpoint_name))
+
+ try:
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
+ except BaseException:
+ raise ValueError("Could not load checkpoint")
+
+ # load the LM state dict into each model
+ model_dict = state_dict['model']['language_model']
+ self.query_model.language_model.load_state_dict(model_dict)
+ self.block_model.language_model.load_state_dict(model_dict)
+
+ # give each model the same ict_head to begin with as well
+ query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
+ self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
+
+
+class IREncoderBertModel(MegatronModule):
+ """BERT-based encoder for queries or blocks used for learned information retrieval."""
+ def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
+ super(IREncoderBertModel, self).__init__()
+ args = get_args()
+
+ self.ict_head_size = ict_head_size
+ self.parallel_output = parallel_output
+ init_method = init_method_normal(args.init_method_std)
+ scaled_init_method = scaled_init_method_normal(args.init_method_std,
+ args.num_layers)
+
+ self.language_model, self._language_model_key = get_language_model(
+ attention_mask_func=bert_attention_mask_func,
+ num_tokentypes=num_tokentypes,
+ add_pooler=True,
+ init_method=init_method,
+ scaled_init_method=scaled_init_method)
+
+ self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
+ self._ict_head_key = 'ict_head'
+
+ def forward(self, input_ids, attention_mask, tokentype_ids=None):
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
+ position_ids = bert_position_ids(input_ids)
+
+ lm_output, pooled_output = self.language_model(
+ input_ids,
+ position_ids,
+ extended_attention_mask,
+ tokentype_ids=tokentype_ids)
+
+ # Output.
+ ict_logits = self.ict_head(pooled_output)
+ return ict_logits, None
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """For easy load when model is combined with other heads,
+ add an extra key."""
+
+ state_dict_ = {}
+ state_dict_[self._language_model_key] \
+ = self.language_model.state_dict_for_save_checkpoint(
+ destination, prefix, keep_vars)
+ state_dict_[self._ict_head_key] \
+ = self.ict_head.state_dict(destination, prefix, keep_vars)
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+ self.language_model.load_state_dict(
+ state_dict[self._language_model_key], strict=strict)
+ self.ict_head.load_state_dict(
+ state_dict[self._ict_head_key], strict=strict)
+
+
diff --git a/megatron_lm/megatron/model/transformer.py b/megatron_lm/megatron/model/transformer.py
new file mode 100644
index 0000000..69c3d44
--- /dev/null
+++ b/megatron_lm/megatron/model/transformer.py
@@ -0,0 +1,1079 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer."""
+
+import math
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import mpu
+from megatron.mpu import LayerNorm
+from megatron.module import MegatronModule
+from megatron.checkpointing import get_checkpoint_version
+from megatron.model.fused_softmax import FusedScaleMaskSoftmax
+from megatron.model.fused_bias_gelu import bias_gelu_impl
+
+import deepspeed
+
+# flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+""" We use the following notation throughout this file:
+ h: hidden size
+ n: number of attention heads
+ p: number of model parallel partitions
+ np: n/p
+ hp: h/p
+ hn: h/n
+ b: batch size
+ s: sequence length
+ l: number of layers
+ Transformer takes input of size [s, b, h] and returns a
+ tensor of the same size. We use the following arguments:
+ hyperparameters: transformer hyperparameters
+    attention_mask_func: a function that takes `unmasked-attention-scores`
+        with size [b, np, s, s] and an `attention-mask` and will apply
+        the masking. The function should return a masked score of the
+        same size [b, np, s, s].
+               masked-attention-scores = attention_mask_func(
+                                     unmasked-attention-scores, attention-mask)
+"""
+
+class ParallelMLP(MegatronModule):
+ """MLP.
+
+    The MLP takes the input with hidden size h, projects it to a 4*h
+    hidden dimension, applies a nonlinear transformation, and projects the
+    state back to hidden size h. At the end, dropout is also applied.
+ """
+
+ def __init__(self, init_method, output_layer_init_method):
+ super(ParallelMLP, self).__init__()
+ args = get_args()
+
+ if not args.memory_centric_tiled_linear:
+ self.dense_ffn_hidden = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.intermediate_size,
+ gather_output=False,
+ init_method=init_method,
+ skip_bias_add=True)
+ else:
+ self.dense_ffn_hidden = deepspeed.zero.TiledLinearReturnBias(
+ in_features=args.hidden_size,
+ out_features=4*args.hidden_size,
+ linear_cls=mpu.ColumnParallelLinear,
+ in_splits=args.tile_factor,
+ out_splits=4*args.tile_factor,
+ combine_out_splits=True,
+ gather_output=False,
+ init_method=init_method,
+ skip_bias_add=True)
+
+ self.bias_gelu_fusion = args.bias_gelu_fusion
+ self.activation_type = args.activation_type
+ self.is_gated = args.activation_type in ['geglu']
+
+ self.activation_func = F.gelu
+
+ if self.is_gated:
+ self.dense_ffn_gate = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ args.intermediate_size,
+ gather_output=False,
+ init_method=init_method,
+ skip_bias_add=False,
+ )
+
+ if not args.memory_centric_tiled_linear:
+ self.dense_ffn_output = mpu.RowParallelLinear(
+ args.intermediate_size,
+ args.hidden_size,
+ input_is_parallel=True,
+ init_method=output_layer_init_method,
+ skip_bias_add=True)
+ else:
+ self.dense_ffn_output = deepspeed.zero.TiledLinearReturnBias(
+ in_features=4*args.hidden_size,
+ out_features=args.hidden_size,
+ linear_cls=mpu.RowParallelLinear,
+ in_splits=4*args.tile_factor,
+ out_splits=args.tile_factor,
+ input_is_already_split=False,
+ combine_out_splits=True,
+ input_is_parallel=True,
+ init_method=output_layer_init_method,
+ skip_bias_add=True)
+
+ def forward(self, hidden_states):
+ intermediate_parallel, bias_parallel = self.dense_ffn_hidden(hidden_states)
+
+ if self.bias_gelu_fusion:
+ intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
+ else:
+ intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
+
+ if self.is_gated:
+ gate = self.dense_ffn_gate(hidden_states)[0]
+ intermediate_gated = intermediate_parallel * gate
+ else:
+ intermediate_gated = intermediate_parallel
+
+ output, output_bias = self.dense_ffn_output(intermediate_gated)
+ return output, output_bias
+
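+# Shape sketch for ParallelMLP (illustrative, assuming intermediate_size = 4h
+# and model-parallel size p): each partition sees
+#   [..., h] --ColumnParallelLinear--> [..., 4h/p] --gelu--> [..., 4h/p]
+#   --RowParallelLinear (all-reduce over the p partitions)--> [..., h]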
+
+class ParallelSelfAttention(MegatronModule):
+ """Parallel self-attention layer abstract class.
+
+ Self-attention layer takes input with size [b, s, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, attention_mask_func, init_method,
+ output_layer_init_method, layer_number):
+ super(ParallelSelfAttention, self).__init__()
+ args = get_args()
+ self.fp16 = args.fp16
+
+ self.attention_mask_func = attention_mask_func
+ self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+ self.layer_number = max(1, layer_number)
+ self.pos_encoding_type = args.pos_encoding_type
+
+ # Per attention head and per partition values.
+ world_size = mpu.get_model_parallel_world_size()
+ self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+ world_size)
+ self.hidden_size_per_attention_head = mpu.divide(
+ args.hidden_size, args.num_attention_heads)
+ self.num_attention_heads_per_partition = mpu.divide(
+ args.num_attention_heads, world_size)
+
+ # Strided linear layer.
+ if not args.memory_centric_tiled_linear:
+ self.query_key_value = mpu.ColumnParallelLinear(
+ args.hidden_size,
+ 3 * args.hidden_size,
+ gather_output=False,
+ init_method=init_method)
+ else:
+ self.query_key_value = deepspeed.zero.TiledLinearReturnBias(
+ in_features=args.hidden_size,
+ out_features=3*args.hidden_size,
+ linear_cls=mpu.ColumnParallelLinear,
+ gather_output=False,
+ init_method=init_method,
+ in_splits=args.tile_factor,
+ out_splits=3*args.tile_factor,
+ combine_out_splits=True
+ )
+
+ coeff = None
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ if self.apply_query_key_layer_scaling:
+ coeff = self.layer_number
+ self.norm_factor *= coeff
+
+ self.scale_mask_softmax = FusedScaleMaskSoftmax(
+ self.fp16,
+ args.scaled_upper_triang_masked_softmax_fusion,
+ args.scaled_masked_softmax_fusion,
+ self.attention_mask_func,
+ self.attention_softmax_in_fp32,
+ coeff)
+
+        # Dropout. Note that for a single iteration, this layer will generate
+        # different outputs on different numbers of parallel partitions, but
+        # on average it should not be partition dependent.
+ self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+
+ # Output.
+ if not args.memory_centric_tiled_linear:
+ self.dense = mpu.RowParallelLinear(
+ args.hidden_size,
+ args.hidden_size,
+ input_is_parallel=True,
+ init_method=output_layer_init_method,
+ skip_bias_add=True)
+ else:
+ self.dense = deepspeed.zero.TiledLinearReturnBias(
+ in_features=args.hidden_size,
+ out_features=args.hidden_size,
+ linear_cls=mpu.RowParallelLinear,
+ input_is_parallel=True,
+ init_method=output_layer_init_method,
+ skip_bias_add=True,
+ out_splits=args.tile_factor,
+ in_splits=args.tile_factor,
+ combine_out_splits=True
+ )
+
+ if deepspeed.checkpointing.is_configured():
+ global get_cuda_rng_tracker, checkpoint
+ get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
+ checkpoint = deepspeed.checkpointing.checkpoint
+
+ if args.pos_encoding_type == 'rotary':
+ self.rotary_position_encoding = RotaryPositionEncoding(
+ args.max_position_embeddings,
+ self.hidden_size_per_attention_head,
+ args.params_dtype,
+ )
+
+ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
+        input_shape = mixed_layer.size()
+ if num_splits_first:
+ """[s, b, num_splits * np * hn]
+ -->(view) [s, b, num_splits, np, hn]
+            -->(transpose) [s, b, np, num_splits, hn]
+ -->(view) [s, b, np * num_splits * hn] """
+
+ intermediate_shape = input_shape[:-1] +\
+ (num_splits, self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+
+ mixed_layer = mixed_layer.view(*intermediate_shape)
+ mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
+ else:
+ """[s, b, np * hn * num_splits]
+ -->(view) [s, b, np, hn, num_splits]
+            -->(transpose) [s, b, np, num_splits, hn]
+ -->(view) [s, b, np * num_splits * hn] """
+
+ intermediate_shape = input_shape[:-1] +\
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head, num_splits)
+
+ mixed_layer = mixed_layer.view(*intermediate_shape)
+ mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
+ mixed_layer = mixed_layer.view(*input_shape)
+
+ return mixed_layer
+
+ def forward(self, hidden_states, attention_mask, layer_past=None,
+ get_key_value=False):
+ # hidden_states: [sq, b, h]
+
+ # =====================
+ # Query, Key, and Value
+ # =====================
+
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+ checkpoint_version = get_checkpoint_version()
+ if checkpoint_version is not None:
+ if checkpoint_version == 0:
+ # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
+ mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
+ elif checkpoint_version == 1.0:
+ # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
+ mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
+ else:
+ pass # already [sq, b, (np * 3 * hn)]
+
+ # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head)
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer,
+ key_layer,
+ value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ if self.pos_encoding_type == 'rotary':
+ context_position = 0 if layer_past is None else layer_past[2]
+ query_layer = self.rotary_position_encoding(query_layer, context_position)
+ key_layer = self.rotary_position_encoding(key_layer, context_position)
+
+ # ==================================
+ # Adjust key and value for inference
+ # ==================================
+
+ if layer_past is not None:
+ past_key, past_value, sq_length = layer_past
+ key_layer = torch.cat((past_key.type_as(key_layer),
+ key_layer), dim=0)
+ value_layer = torch.cat((past_value.type_as(value_layer),
+ value_layer), dim=0)
+ sq_length += 1
+ else:
+ sq_length = key_layer.size()[0]
+
+ if get_key_value:
+ present = (key_layer, value_layer, sq_length)
+
+
+ # ===================================
+ # Raw attention scores. [b, np, s, s]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1),
+ query_layer.size(2),
+ query_layer.size(0),
+ key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(output_size[2],
+ output_size[0] * output_size[1], -1)
+ key_layer = key_layer.view(output_size[3],
+ output_size[0] * output_size[1], -1)
+
+        # preallocating result tensor: [b * np, sq, sk]
+ matmul_result = torch.empty(
+ output_size[0]*output_size[1],
+ output_size[2],
+ output_size[3],
+ dtype=query_layer.dtype,
+ device=torch.cuda.current_device())
+
+ # Raw attention scores. [b * np, sq, sk]
+ matmul_result = torch.baddbmm(matmul_result,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk]
+ beta=0.0, alpha=(1.0/self.norm_factor))
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+
+ # ==================================================
+ # Update attention mask for inference. [b, np, sq, sk]
+ # ==================================================
+
+ if get_key_value:
+ with torch.no_grad():
+ if layer_past is not None:
+ attention_mask = attention_mask[
+ ...,
+ attention_scores.size(3) - 1,
+ :attention_scores.size(3)].unsqueeze(2)
+ else:
+ attention_mask = attention_mask[
+ ...,
+ :attention_scores.size(3),
+ :attention_scores.size(3)]
+
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ # attention scores and attention mask [b, np, sq, sk]
+ attention_probs = self.scale_mask_softmax(attention_scores,
+ attention_mask)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ with mpu.get_cuda_rng_tracker().fork():
+ attention_probs = self.attention_dropout(attention_probs)
+
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1),
+ value_layer.size(2),
+ query_layer.size(0),
+ value_layer.size(3))
+
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0),
+ output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1],
+ output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + \
+ (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.dense(context_layer)
+
+ if get_key_value:
+ output = [output, present]
+
+ return output, bias
+
+
+def bias_dropout_add(x, bias, residual, prob, training) :
+ # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+ out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+def get_bias_dropout_add(training):
+ def _bias_dropout_add(x, bias, residual, prob):
+ return bias_dropout_add(x, bias, residual, prob, training)
+ return _bias_dropout_add
+
+
+@torch.jit.script
+def bias_dropout_add_fused_train(x, bias, residual, prob) :
+ # type: (Tensor, Tensor, Tensor, float) -> Tensor
+ return bias_dropout_add(x, bias, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_fused_inference(x, bias, residual, prob) :
+ # type: (Tensor, Tensor, Tensor, float) -> Tensor
+ return bias_dropout_add(x, bias, residual, prob, False)
+
+
+class ParallelTransformerLayer(MegatronModule):
+ """A single transformer layer.
+
+    Transformer layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, attention_mask_func, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTransformerLayer, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+ # Memory-saving optimization
+ self.scattered_attn_output = args.scattered_embeddings
+
+ # Layernorm on the input data.
+ if self.layer_number > 1:
+ self.input_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon
+ )
+
+ # Self attention.
+ self.attention = ParallelSelfAttention(attention_mask_func, init_method,
+ output_layer_init_method,
+ layer_number)
+ self.hidden_dropout = args.hidden_dropout
+ self.bias_dropout_fusion = args.bias_dropout_fusion
+
+        # Layernorm on the attention output.
+ self.post_attention_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # MLP
+ self.mlp = ParallelMLP(init_method,
+ output_layer_init_method)
+
+
+ def forward(self, hidden_states, attention_mask, layer_past=None,
+ get_key_value=False):
+ # hidden_states: [b, s, h]
+
+        # Layer norm at the beginning of the transformer layer.
+ if self.layer_number > 1:
+ attention_input = self.input_layernorm(hidden_states)
+ else:
+ attention_input = hidden_states
+
+ # Self attention.
+ attention_output, attention_bias = \
+ self.attention(attention_input,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value)
+
+ if get_key_value:
+ attention_output, presents = attention_output
+
+ if self.scattered_attn_output:
+ attention_output = mpu.scatter_to_model_parallel_region(attention_output)
+ attention_bias = mpu.scatter_to_model_parallel_region(attention_bias)
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = attention_input
+ else:
+ residual = hidden_states
+
+ if self.scattered_attn_output:
+ residual = mpu.scatter_to_model_parallel_region(residual)
+
+        # jit scripting for an nn.Module (with dropout) is not
+        # triggering the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        # Re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ layernorm_input = bias_dropout_add_func(
+ attention_output,
+ attention_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ # Collect the scattered result from the fused dropout.
+ if self.scattered_attn_output:
+ layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input)
+ # Attention output/bias are not used again, so no need to gather
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output, mlp_bias = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+        # Re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ output = bias_dropout_add_func(
+ mlp_output,
+ mlp_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
+
+class ParallelTransformerLayerPart1(MegatronModule):
+ """A single transformer layer.
+
+    Transformer layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, attention_mask_func, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTransformerLayerPart1, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # Self attention.
+ self.attention = ParallelSelfAttention(attention_mask_func, init_method,
+ output_layer_init_method,
+ layer_number)
+        self.hidden_dropout = args.hidden_dropout
+        self.bias_dropout_fusion = args.bias_dropout_fusion
+        # The scatter/gather path of the full layer is not wired up in the
+        # split layers, so keep scattered attention output disabled here.
+        self.scattered_attn_output = False
+
+
+ def forward(self, hidden_states, attention_mask, layer_past=None,
+ get_key_value=False):
+ # hidden_states: [b, s, h]
+
+        # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, attention_bias = \
+ self.attention(layernorm_output,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value)
+
+ presents = None
+ if get_key_value:
+ raise NotImplementedError('get_key_value param is not yet supported with split-transformers')
+ attention_output, presents = attention_output
+
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ if self.scattered_attn_output:
+ residual = mpu.scatter_to_model_parallel_region(residual)
+
+        # jit scripting for an nn.Module (with dropout) is not
+        # triggering the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        # Re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ layernorm_input = bias_dropout_add_func(
+ attention_output,
+ attention_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ return layernorm_input
+
+class ParallelTransformerLayerPart2(MegatronModule):
+ """A single transformer layer.
+
+    Transformer layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, attention_mask_func, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTransformerLayerPart2, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+        self.hidden_dropout = args.hidden_dropout
+        self.bias_dropout_fusion = args.bias_dropout_fusion
+        # As in ParallelTransformerLayerPart1 above, the scattered attention
+        # output path is not supported here.
+        self.scattered_attn_output = False
+
+        # Layernorm on the attention output.
+ self.post_attention_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # MLP
+ self.mlp = ParallelMLP(init_method,
+ output_layer_init_method)
+
+
+ def forward(self, layernorm_input, attention_mask, presents=None, layer_past=None,
+ get_key_value=False):
+ # hidden_states: [b, s, h]
+
+ # Collect the scattered result from the fused dropout.
+ if self.scattered_attn_output:
+ layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input)
+ # Attention output/bias are not used again, so no need to gather
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output, mlp_bias = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+        # jit scripting for an nn.Module (with dropout) is not
+        # triggering the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        # Re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ output = bias_dropout_add_func(
+ mlp_output,
+ mlp_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
+
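+# NOTE: the two class definitions below redefine ParallelTransformerLayerPart1
+# and ParallelTransformerLayerPart2; because Python binds a module-level class
+# name to its last definition, only these later versions (which omit the
+# scattered-output handling) are used at runtime.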
+class ParallelTransformerLayerPart1(MegatronModule):
+ """A single transformer layer.
+
+    Transformer layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, attention_mask_func, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTransformerLayerPart1, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # Self attention.
+ self.attention = ParallelSelfAttention(attention_mask_func, init_method,
+ output_layer_init_method,
+ layer_number)
+ self.hidden_dropout = args.hidden_dropout
+ self.bias_dropout_fusion = args.bias_dropout_fusion
+
+
+ def forward(self, hidden_states, attention_mask, layer_past=None,
+ get_key_value=False):
+ # hidden_states: [b, s, h]
+
+        # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, attention_bias = \
+ self.attention(layernorm_output,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=get_key_value)
+
+ presents = None
+ if get_key_value:
+ raise NotImplementedError('get_key_value param is not yet supported with split-transformers')
+ attention_output, presents = attention_output
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+        # jit scripting for an nn.Module (with dropout) is not
+        # triggering the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        # Re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ layernorm_input = bias_dropout_add_func(
+ attention_output,
+ attention_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ return layernorm_input
+
+class ParallelTransformerLayerPart2(MegatronModule):
+ """A single transformer layer.
+
+    Transformer layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, attention_mask_func, init_method,
+ output_layer_init_method, layer_number):
+ args = get_args()
+
+ super(ParallelTransformerLayerPart2, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm \
+ = args.apply_residual_connection_post_layernorm
+
+ self.hidden_dropout = args.hidden_dropout
+ self.bias_dropout_fusion = args.bias_dropout_fusion
+
+        # Layernorm on the attention output.
+ self.post_attention_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon)
+
+ # MLP
+ self.mlp = ParallelMLP(init_method,
+ output_layer_init_method)
+
+
+ def forward(self, layernorm_input, attention_mask, presents=None, layer_past=None,
+ get_key_value=False):
+ # hidden_states: [b, s, h]
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output, mlp_bias = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+        # jit scripting for an nn.Module (with dropout) is not
+        # triggering the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        # Re-enable torch grad to enable fused optimization.
+ with torch.enable_grad():
+ output = bias_dropout_add_func(
+ mlp_output,
+ mlp_bias.expand_as(residual),
+ residual,
+ self.hidden_dropout)
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
+
+
+class ParallelTransformer(MegatronModule):
+ """Transformer class."""
+
+ def __init__(self, attention_mask_func,
+ init_method, output_layer_init_method):
+ super(ParallelTransformer, self).__init__()
+ args = get_args()
+
+        # Store activation checkpointing flag.
+ self.checkpoint_activations = args.checkpoint_activations
+ self.checkpoint_num_layers = args.checkpoint_num_layers
+
+ # Number of layers:
+ self.num_layers = args.num_layers
+ self.num_unique_layers = args.num_unique_layers
+ if self.num_unique_layers is None:
+ self.num_unique_layers = self.num_layers
+ assert self.num_layers % self.num_unique_layers == 0, \
+ 'number of layers should be divisible by number of unique layers'
+ self.param_sharing_style = args.param_sharing_style
+
+ # Transformer layers.
+ def build_layer(layer_number):
+ return ParallelTransformerLayer(
+ attention_mask_func, init_method,
+ output_layer_init_method, layer_number)
+
+ def build_layer_part1(layer_number):
+ return ParallelTransformerLayerPart1(
+ attention_mask_func, init_method,
+ output_layer_init_method, layer_number)
+ def build_layer_part2(layer_number):
+ return ParallelTransformerLayerPart2(
+ attention_mask_func, init_method,
+ output_layer_init_method, layer_number)
+
+ if args.split_transformers:
+ layers = []
+ for i in range(self.num_unique_layers):
+ layers.append(build_layer_part1(i + 1))
+ layers.append(build_layer_part2(i + 1))
+ self.layers = torch.nn.ModuleList(layers)
+ self.num_layers *= 2
+ self.num_unique_layers *= 2
+ else:
+ self.layers = torch.nn.ModuleList(
+ [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+ # Print layer ordering.
+ if self.num_layers != self.num_unique_layers:
+ if torch.distributed.get_rank() == 0:
+ print('> will be using the following layer ordering:')
+ for i in range(self.num_layers):
+ print(' layer id: {:3d} --> unique layer id: '
+ '{:3d}'.format(i, self._get_layer_index(i)),
+ flush=True)
+
+ # Final layer norm before output.
+ # self.final_layernorm = LayerNorm(
+ # args.hidden_size,
+ # eps=args.layernorm_epsilon)
+
+ if deepspeed.checkpointing.is_configured():
+ global get_cuda_rng_tracker, checkpoint
+ get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
+ checkpoint = deepspeed.checkpointing.checkpoint
+
+ def _get_layer_index(self, layer_number):
+ if self.param_sharing_style == 'grouped':
+ return layer_number % self.num_unique_layers
+ if self.param_sharing_style == 'spaced':
+ return layer_number // (self.num_layers // self.num_unique_layers)
+ assert False, 'should not be here'
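+
+    # Illustrative mapping (assuming num_layers=4, num_unique_layers=2):
+    #   'grouped': layer ids 0, 1, 2, 3 -> unique layer ids 0, 1, 0, 1
+    #   'spaced' : layer ids 0, 1, 2, 3 -> unique layer ids 0, 0, 1, 1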
+
+ def _get_layer(self, layer_number):
+ return self.layers[self._get_layer_index(layer_number)]
+
+ def _checkpointed_forward(self, hidden_states, attention_mask):
+ """Forward method with activation checkpointing."""
+ def custom(start, end):
+ def custom_forward(*inputs):
+ x_ = inputs[0]
+ for index in range(start, end):
+ layer = self._get_layer(index)
+ x_ = layer(x_, inputs[1])
+ return x_
+ return custom_forward
+
+ # Make sure memory is freed.
+ mpu.reset_checkpointed_activations_memory_buffer()
+ l = 0
+ while l < self.num_layers:
+ hidden_states = mpu.checkpoint(
+ custom(l, l + self.checkpoint_num_layers),
+ hidden_states, attention_mask)
+ l += self.checkpoint_num_layers
+
+ return hidden_states
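+
+    # E.g. (illustrative): with num_layers=24 and checkpoint_num_layers=4, the
+    # loop above runs the transformer as six checkpointed chunks of four layers;
+    # activations inside each chunk are recomputed during the backward pass.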
+
+ def forward(self, hidden_states, attention_mask, layer_past=None,
+ get_key_value=False):
+
+ # Checks
+ if layer_past is not None:
+ assert get_key_value, \
+ 'for not None values in layer_past, ' \
+ 'expected get_key_value to be set'
+ if get_key_value:
+ assert not self.checkpoint_activations, \
+ 'get_key_value does not work with ' \
+ 'activation checkpointing'
+
+        # data format change to avoid explicit transposes: [b s h] --> [s b h]
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+ if self.checkpoint_activations:
+ hidden_states = self._checkpointed_forward(hidden_states,
+ attention_mask)
+ else:
+ if get_key_value:
+ presents = []
+ for index in range(self.num_layers):
+ layer = self._get_layer(index)
+ past = None
+ if layer_past is not None:
+ past = layer_past[index]
+ hidden_states = layer(hidden_states,
+ attention_mask,
+ layer_past=past,
+ get_key_value=get_key_value)
+ if get_key_value:
+ hidden_states, present = hidden_states
+ presents.append(present)
+
+ # reverting data format change [s b h] --> [b s h]
+ output = hidden_states.transpose(0, 1).contiguous()
+
+ # # Final layer norm.
+ # output = self.final_layernorm(hidden_states)
+
+ if get_key_value:
+ output = [output, presents]
+
+ return output
+
+
+class RotaryPositionEncoding(MegatronModule):
+ def __init__(self, max_seq_length, hidden_size_per_attention_head, dtype):
+ super(RotaryPositionEncoding, self).__init__()
+ cos_cached, sin_cached = RotaryPositionEncoding.get_cache_multipliers(
+ max_seq_length, hidden_size_per_attention_head, dtype
+ )
+ self.register_buffer("cos_cached", cos_cached.unsqueeze(1).unsqueeze(2), persistent=False)
+ self.register_buffer("sin_cached", sin_cached.unsqueeze(1).unsqueeze(2), persistent=False)
+
+ def forward(self, hidden_state, context_position):
+ seq_length = hidden_state.shape[0]
+ cache_slice = slice(context_position, context_position + seq_length)
+ return self.apply_rotary_position_encoding(
+ hidden_state, self.cos_cached[cache_slice], self.sin_cached[cache_slice]
+ )
+
+ @staticmethod
+ def get_cache_multipliers(max_seq_length, hidden_size, dtype):
+ inv_freqs = 1e-4 ** (torch.arange(0, hidden_size, 2, dtype=torch.float) / hidden_size)
+ positions = torch.arange(max_seq_length, dtype=torch.float)
+ angles = positions.unsqueeze(-1) * inv_freqs
+
+ cos, sin = torch.cos(angles), torch.sin(angles)
+ return cos.to(dtype), sin.to(dtype)
+
+ @staticmethod
+ def apply_rotary_position_encoding(hidden_state, cos_cached, sin_cached):
+ sq, b, np, hn = hidden_state.shape
+ half_hn = hn // 2
+ left, right = hidden_state[..., :half_hn], hidden_state[..., half_hn:]
+ encoded_left = cos_cached * left - sin_cached * right
+ encoded_right = sin_cached * left + cos_cached * right
+ return torch.cat((encoded_left, encoded_right), dim=3)
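+
+# Rotary sketch (illustrative): with inv_freq_i = 1e-4 ** (2i / hn) and token
+# position m, apply_rotary_position_encoding rotates each half-dimension pair
+# (left_i, right_i) by the angle m * inv_freq_i:
+#   encoded_left_i  = cos(m * inv_freq_i) * left_i  - sin(m * inv_freq_i) * right_i
+#   encoded_right_i = sin(m * inv_freq_i) * left_i  + cos(m * inv_freq_i) * right_i
+# using the cos/sin tables precomputed in get_cache_multipliers.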
diff --git a/megatron_lm/megatron/model/utils.py b/megatron_lm/megatron/model/utils.py
new file mode 100644
index 0000000..c309c4b
--- /dev/null
+++ b/megatron_lm/megatron/model/utils.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for models."""
+
+import math
+
+import torch
+
+from .transformer import LayerNorm
+
+
+def init_method_normal(sigma):
+ """Init method based on N(0, sigma)."""
+ def init_(tensor):
+ return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
+
+ return init_
+
+
+def scaled_init_method_normal(sigma, num_layers):
+    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
+ std = sigma / math.sqrt(2.0 * num_layers)
+
+ def init_(tensor):
+ return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+ return init_
+
+
+def get_linear_layer(rows, columns, init_method):
+ """Simple linear layer with weight initialization."""
+ layer = torch.nn.Linear(rows, columns)
+ init_method(layer.weight)
+ with torch.no_grad():
+ layer.bias.zero_()
+ return layer
+
+@torch.jit.script
+def gelu_impl(x):
+ """OpenAI's gelu implementation."""
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
+ (1.0 + 0.044715 * x * x)))
+def openai_gelu(x):
+ return gelu_impl(x)
+
+# This is the Python equivalent of torch.nn.functional.gelu(), with type hints for the ONNX exporter.
+@torch.jit.script
+def erf_gelu(x):
+ return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
+
+def get_params_for_weight_decay_optimization(module):
+ """Divide params into with-weight-decay and without-weight-decay groups.
+    Layernorms and biases will have no weight decay but the rest will.
+ """
+ weight_decay_params = {'params': []}
+ no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+ for module_ in module.modules():
+ if isinstance(module_, LayerNorm):
+ no_weight_decay_params['params'].extend(
+ [p for p in list(module_._parameters.values())
+ if p is not None])
+ else:
+ weight_decay_params['params'].extend(
+ [p for n, p in list(module_._parameters.items())
+ if p is not None and n != 'bias'])
+ no_weight_decay_params['params'].extend(
+ [p for n, p in list(module_._parameters.items())
+ if p is not None and n == 'bias'])
+
+ return weight_decay_params, no_weight_decay_params
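+
+
+# Usage sketch (illustrative; the optimizer choice is an assumption): the two
+# returned groups can be passed straight to an optimizer so that LayerNorm
+# parameters and biases are exempt from weight decay, e.g.
+#
+#     param_groups = get_params_for_weight_decay_optimization(model)
+#     optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=0.01)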
diff --git a/megatron_lm/megatron/module.py b/megatron_lm/megatron/module.py
new file mode 100644
index 0000000..a78c228
--- /dev/null
+++ b/megatron_lm/megatron/module.py
@@ -0,0 +1,31 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron Module"""
+
+import torch
+
+
+class MegatronModule(torch.nn.Module):
+    """Megatron-specific extensions of torch Module."""
+
+ def __init__(self):
+ super(MegatronModule, self).__init__()
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ """Use this function to override the state dict for
+ saving checkpoints."""
+ return self.state_dict(destination, prefix, keep_vars)
diff --git a/megatron_lm/megatron/mpu/__init__.py b/megatron_lm/megatron/mpu/__init__.py
new file mode 100644
index 0000000..cf138b5
--- /dev/null
+++ b/megatron_lm/megatron/mpu/__init__.py
@@ -0,0 +1,53 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model parallel utility interface."""
+
+from .cross_entropy import vocab_parallel_cross_entropy
+
+from .data import broadcast_data
+
+from .grads import clip_grad_norm
+
+from .initialize import is_unitialized
+from .initialize import destroy_model_parallel
+from .initialize import get_data_parallel_group
+from .initialize import get_data_parallel_rank
+from .initialize import get_data_parallel_world_size
+from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_rank, set_model_parallel_rank
+from .initialize import get_model_parallel_src_rank
+from .initialize import get_model_parallel_world_size, set_model_parallel_world_size
+from .initialize import initialize_model_parallel
+from .initialize import model_parallel_is_initialized
+
+from .layers import LayerNorm
+from .layers import ColumnParallelLinear
+from .layers import RowParallelLinear
+from .layers import VocabParallelEmbedding
+
+from .mappings import copy_to_model_parallel_region
+from .mappings import gather_from_model_parallel_region
+from .mappings import reduce_from_model_parallel_region
+from .mappings import scatter_to_model_parallel_region
+
+from .random import checkpoint
+from .random import get_cuda_rng_tracker
+from .random import init_checkpointed_activations_memory_buffer
+from .random import model_parallel_cuda_manual_seed
+from .random import reset_checkpointed_activations_memory_buffer
+
+from .utils import divide
+from .utils import split_tensor_along_last_dim
diff --git a/megatron_lm/megatron/mpu/cross_entropy.py b/megatron_lm/megatron/mpu/cross_entropy.py
new file mode 100644
index 0000000..79ea83d
--- /dev/null
+++ b/megatron_lm/megatron/mpu/cross_entropy.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_world_size
+from .utils import VocabUtility
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, vocab_parallel_logits, target):
+
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ torch.distributed.all_reduce(logits_max,
+ op=torch.distributed.ReduceOp.MAX,
+ group=get_model_parallel_group())
+ # Subtract the maximum value.
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+        # Get the partition's vocab indices
+ get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ rank = get_model_parallel_rank()
+ world_size = get_model_parallel_world_size()
+ vocab_start_index, vocab_end_index = get_vocab_range(
+ partition_vocab_size, rank, world_size)
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+        # For simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
+ device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0
+ # All reduce is needed to get the chunks from other GPUs.
+ torch.distributed.all_reduce(predicted_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=get_model_parallel_group())
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = vocab_parallel_logits
+ torch.exp(vocab_parallel_logits, out=exp_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+ torch.distributed.all_reduce(sum_exp_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=get_model_parallel_group())
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ loss = torch.log(sum_exp_logits) - predicted_logits
+
+ # Store softmax, target-mask and masked-target for backward pass.
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+        # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
+ device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= (
+ 1.0 - target_mask.view(-1).float())
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+ return grad_input, None
+
+
+def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
+ """Helper function for the cross entropy."""
+ return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
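+
+
+# Shape sketch (illustrative): vocab_parallel_logits is [..., vocab_size / p] on
+# each of the p model-parallel ranks, target holds global vocab ids with shape
+# [...]; the returned per-token loss also has shape [...] and is identical on
+# every rank thanks to the all-reduces in the forward pass.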
diff --git a/megatron_lm/megatron/mpu/data.py b/megatron_lm/megatron/mpu/data.py
new file mode 100644
index 0000000..84b0af6
--- /dev/null
+++ b/megatron_lm/megatron/mpu/data.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_src_rank
+
+
+_MAX_DATA_DIM = 4
+
+
+def _check_data_types(keys, data, target_dtype):
+ """Check that all the keys have the same target data type."""
+ for key in keys:
+ assert data[key].dtype == target_dtype, '{} has data type {} which '\
+ 'is different than {}'.format(key, data[key].dtype, target_dtype)
+
+
+def _build_key_size_numel_dictionaries(keys, data):
+ """Build the size on rank 0 and broadcast."""
+ max_dim = _MAX_DATA_DIM
+ sizes = [0 for _ in range(max_dim) for _ in keys]
+
+ # Pack the sizes on rank zero.
+ if get_model_parallel_rank() == 0:
+ offset = 0
+ for key in keys:
+ assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
+ size = data[key].size()
+ for i, s in enumerate(size):
+ sizes[i + offset] = s
+ offset += max_dim
+
+ # Move to GPU and broadcast.
+ sizes_cuda = torch.cuda.LongTensor(sizes)
+ torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(),
+ group=get_model_parallel_group())
+
+ # Move back to cpu and unpack.
+ sizes_cpu = sizes_cuda.cpu()
+ key_size = {}
+ key_numel = {}
+ total_numel = 0
+ offset = 0
+ for key in keys:
+ i = 0
+ size = []
+ numel = 1
+ while sizes_cpu[offset + i] > 0:
+ this_size = sizes_cpu[offset + i]
+ size.append(this_size)
+ numel *= this_size
+ i += 1
+ key_size[key] = size
+ key_numel[key] = numel
+ total_numel += numel
+ offset += max_dim
+
+ return key_size, key_numel, total_numel
+
+
+def broadcast_data(keys, data, datatype):
+ """Broadcast data from rank zero of each model parallel group to the
+ members of the same model parallel group.
+
+ Arguments:
+        keys: list of keys in the data dictionary to be broadcast
+ data: data dictionary of string keys and cpu tensor values.
+ datatype: torch data type of all tensors in data associated
+ with keys.
+ """
+ # Build (key, size) and (key, number of elements) dictionaries along
+ # with the total number of elements on all ranks.
+ key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
+ data)
+
+ # Pack on rank zero.
+ if get_model_parallel_rank() == 0:
+ # Check that all keys have the same data type.
+ _check_data_types(keys, data, datatype)
+ # Flatten the data associated with the keys
+ flatten_data = torch.cat(
+ [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
+ else:
+ flatten_data = torch.empty(total_numel,
+ device=torch.cuda.current_device(),
+ dtype=datatype)
+
+    # Broadcast
+ torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(),
+ group=get_model_parallel_group())
+
+ # Unpack
+ output = {}
+ offset = 0
+ for key in keys:
+ size = key_size[key]
+ numel = key_numel[key]
+ output[key] = flatten_data.narrow(0, offset, numel).view(size)
+ offset += numel
+
+ return output
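+
+
+# Usage sketch (illustrative; the keys are hypothetical):
+#
+#     batch = broadcast_data(['text', 'labels'], data, torch.int64)
+#
+# Only rank zero of each model parallel group has to supply `data`; the other
+# ranks can pass None, since the dictionary is read on rank zero only and they
+# receive the flattened tensors through the broadcast.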
diff --git a/megatron_lm/megatron/mpu/grads.py b/megatron_lm/megatron/mpu/grads.py
new file mode 100644
index 0000000..f1e511e
--- /dev/null
+++ b/megatron_lm/megatron/mpu/grads.py
@@ -0,0 +1,127 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parts of the code here are adapted from PyTorch
+# repo: https://github.com/pytorch/pytorch
+
+
+import torch
+from torch._six import inf
+
+try:
+ from apex.multi_tensor_apply import multi_tensor_applier
+ import amp_C
+
+except Exception as e:
+ print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
+
+from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_rank
+
+
+def l2_grad_clipper(parameters, max_norm):
+ """Efficient L2 norm gradient clipping."""
+
+ overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
+ # Make sure we have an iterable.
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ # Filter parameters with gradients.
+ parameters_with_grads = list(filter(
+ lambda p: p.grad is not None, parameters))
+ # Filter parameters for norm calculations.
+ mp_rank_is_zero = (get_model_parallel_rank() == 0)
+ parameters_for_norm = list(filter(
+ lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
+ # Calculate L2 norm.
+ norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ overflow_buf,
+ [parameters_for_norm],
+ False # no per-parameter norm
+ )
+ # Sum across all model parallel GPUs.
+ norm_2 = norm * norm
+ torch.distributed.all_reduce(norm_2,
+ op=torch.distributed.ReduceOp.SUM,
+ group=get_model_parallel_group())
+ total_norm = norm_2.item() ** 0.5
+ # Scale to get max_norm.
+ clip_coef = float(max_norm) / (total_norm + 1.0e-6)
+ grads = [p.grad for p in parameters_with_grads]
+ if clip_coef < 1.0:
+ multi_tensor_applier(
+ amp_C.multi_tensor_scale,
+ overflow_buf,
+ [grads, grads],
+ clip_coef)
+ return total_norm
+
+
+def clip_grad_norm(parameters, max_norm, norm_type=2):
+ """Clips gradient norm of an iterable of parameters.
+
+ This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
+ added functionality to handle model parallel parameters. Note that
+ the gradients are modified in place.
+
+ Arguments:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+
+ Returns:
+ Total norm of the parameters (viewed as a single vector).
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if norm_type == inf:
+ total_norm = max(p.grad.data.abs().max() for p in parameters)
+ total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ # Take max across all GPUs.
+ torch.distributed.all_reduce(total_norm_cuda,
+ op=torch.distributed.ReduceOp.MAX,
+ group=get_model_parallel_group())
+ total_norm = total_norm_cuda[0].item()
+ clip_coef = max_norm / (total_norm + 1e-6)
+ if clip_coef < 1:
+ for p in parameters:
+ p.grad.data.mul_(clip_coef)
+ #elif norm_type == 2:
+ # total_norm = l2_grad_clipper(parameters, max_norm)
+
+ else:
+ total_norm = 0
+ for p in parameters:
+ if p.model_parallel or (get_model_parallel_rank() == 0):
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ # Sum across all model parallel GPUs.
+ total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ torch.distributed.all_reduce(total_norm_cuda,
+ op=torch.distributed.ReduceOp.SUM,
+ group=get_model_parallel_group())
+ total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
+ clip_coef = max_norm / (total_norm + 1e-6)
+ if clip_coef < 1:
+ for p in parameters:
+ p.grad.data.mul_(clip_coef)
+ return total_norm
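+
+
+# Usage sketch (illustrative): call after the backward pass and before the
+# optimizer step, e.g.
+#
+#     total_norm = clip_grad_norm(model.parameters(), max_norm=1.0)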
diff --git a/megatron_lm/megatron/mpu/initialize.py b/megatron_lm/megatron/mpu/initialize.py
new file mode 100644
index 0000000..2238347
--- /dev/null
+++ b/megatron_lm/megatron/mpu/initialize.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Model and data parallel groups."""
+
+import torch
+
+from .utils import ensure_divisibility
+
+
+# Model parallel group that the current rank belongs to.
+_MODEL_PARALLEL_GROUP = None
+# Data parallel group that the current rank belongs to.
+_DATA_PARALLEL_GROUP = None
+
+# These values enable us to change the mpu sizes on the fly.
+_MPU_WORLD_SIZE = None
+_MPU_RANK = None
+
+
+def is_unitialized():
+ """Useful for code segments that may be accessed with or without mpu initialization"""
+ return _DATA_PARALLEL_GROUP is None
+
+
+def initialize_model_parallel(model_parallel_size_):
+ """
+ Initialize model data parallel groups.
+
+ Arguments:
+ model_parallel_size: number of GPUs used to parallelize model.
+
+ Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
+ use 2 GPUs to parallelize the model. The present function will
+    create 4 model parallel groups and 2 data parallel groups as:
+ 4 model parallel groups:
+ [g0, g1], [g2, g3], [g4, g5], [g6, g7]
+ 2 data parallel groups:
+ [g0, g2, g4, g6], [g1, g3, g5, g7]
+ Note that for efficiency, the caller should make sure adjacent ranks
+ are on the same DGX box. For example if we are using 2 DGX-1 boxes
+ with a total of 16 GPUs, rank 0 to 7 belong to the first box and
+ ranks 8 to 15 belong to the second box.
+ """
+ if torch.distributed.get_rank() == 0:
+ print('> initializing model parallel with size {}'.format(
+ model_parallel_size_))
+ # Get world size and rank. Ensure some consistencies.
+ assert torch.distributed.is_initialized()
+ world_size = torch.distributed.get_world_size()
+ model_parallel_size = min(model_parallel_size_, world_size)
+ ensure_divisibility(world_size, model_parallel_size)
+ rank = torch.distributed.get_rank()
+
+ # Build the data parallel groups.
+ global _DATA_PARALLEL_GROUP
+ assert _DATA_PARALLEL_GROUP is None, \
+ 'data parallel group is already initialized'
+ for i in range(model_parallel_size):
+ ranks = range(i, world_size, model_parallel_size)
+ group = torch.distributed.new_group(ranks)
+ if i == (rank % model_parallel_size):
+ _DATA_PARALLEL_GROUP = group
+
+ # Build the model parallel groups.
+ global _MODEL_PARALLEL_GROUP
+ assert _MODEL_PARALLEL_GROUP is None, \
+ 'model parallel group is already initialized'
+ for i in range(world_size // model_parallel_size):
+ ranks = range(i * model_parallel_size,
+ (i + 1) * model_parallel_size)
+ group = torch.distributed.new_group(ranks)
+ if i == (rank // model_parallel_size):
+ _MODEL_PARALLEL_GROUP = group
+
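+# Usage sketch (illustrative): every rank first calls
+# torch.distributed.init_process_group(...) and then, for instance,
+# initialize_model_parallel(2) to build the model and data parallel groups laid
+# out in the docstring above.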
+
+def model_parallel_is_initialized():
+ """Check if model and data parallel groups are initialized."""
+ if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
+ return False
+ return True
+
+
+def get_model_parallel_group():
+ """Get the model parallel group the caller rank belongs to."""
+ assert _MODEL_PARALLEL_GROUP is not None, \
+ 'model parallel group is not initialized'
+ return _MODEL_PARALLEL_GROUP
+
+
+def get_data_parallel_group():
+ """Get the data parallel group the caller rank belongs to."""
+ assert _DATA_PARALLEL_GROUP is not None, \
+ 'data parallel group is not initialized'
+ return _DATA_PARALLEL_GROUP
+
+
+def set_model_parallel_world_size(world_size):
+ """Set the model parallel size"""
+ global _MPU_WORLD_SIZE
+ _MPU_WORLD_SIZE = world_size
+
+
+def get_model_parallel_world_size():
+ """Return world size for the model parallel group."""
+ global _MPU_WORLD_SIZE
+ if _MPU_WORLD_SIZE is not None:
+ return _MPU_WORLD_SIZE
+ return torch.distributed.get_world_size(group=get_model_parallel_group())
+
+
+def set_model_parallel_rank(rank):
+ """Set model parallel rank."""
+ global _MPU_RANK
+ _MPU_RANK = rank
+
+
+def get_model_parallel_rank():
+ """Return my rank for the model parallel group."""
+ global _MPU_RANK
+ if _MPU_RANK is not None:
+ return _MPU_RANK
+ return torch.distributed.get_rank(group=get_model_parallel_group())
+
+
+def get_model_parallel_src_rank():
+ """Calculate the global rank corresponding to a local rank zeor
+ in the model parallel group."""
+ global_rank = torch.distributed.get_rank()
+ local_world_size = get_model_parallel_world_size()
+ return (global_rank // local_world_size) * local_world_size
+
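+# Worked example (using the grouping described above): with a model parallel
+# world size of 2, global rank 5 maps to src rank (5 // 2) * 2 = 4, i.e. the
+# first rank of its model parallel group [g4, g5].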
+
+def get_data_parallel_world_size():
+ """Return world size for the data parallel group."""
+ return torch.distributed.get_world_size(group=get_data_parallel_group())
+
+
+def get_data_parallel_rank():
+ """Return my rank for the data parallel group."""
+ return torch.distributed.get_rank(group=get_data_parallel_group())
+
+
+def destroy_model_parallel():
+ """Set the groups to none."""
+ global _MODEL_PARALLEL_GROUP
+ _MODEL_PARALLEL_GROUP = None
+ global _DATA_PARALLEL_GROUP
+ _DATA_PARALLEL_GROUP = None
diff --git a/megatron_lm/megatron/mpu/layers.py b/megatron_lm/megatron/mpu/layers.py
new file mode 100644
index 0000000..b849910
--- /dev/null
+++ b/megatron_lm/megatron/mpu/layers.py
@@ -0,0 +1,369 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parts of the code here are adapted from PyTorch
+# repo: https://github.com/pytorch/pytorch
+
+
+import math
+from deepspeed.runtime.zero.partition_parameters import print_rank_0
+
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from torch.nn.parameter import Parameter
+
+try:
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+ # Instantiate FusedLayerNorm once; if Apex is missing or broken this raises
+ # and we fall back to torch.nn.LayerNorm below.
+ _ = LayerNorm(8, eps=1e-5)
+
+except Exception as e:
+ print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
+ 'instead of apex.normalization.FusedLayerNorm!')
+ from torch.nn import LayerNorm
+
+from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_world_size
+from .mappings import copy_to_model_parallel_region
+from .mappings import gather_from_model_parallel_region
+from .mappings import reduce_from_model_parallel_region
+from .mappings import scatter_to_model_parallel_region
+from .random import get_cuda_rng_tracker
+from .utils import divide
+from .utils import split_tensor_along_last_dim
+from .utils import VocabUtility
+from megatron import get_args
+import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing
+
+def _initialize_affine_weight_gpu(weight, init_method,
+ partition_dim, stride=1):
+ """Initialize affine weight for model parallel on GPU."""
+
+ weight.model_parallel = True
+ weight.partition_dim = partition_dim
+ weight.partition_stride = stride
+
+ if ds_checkpointing.is_configured():
+ global get_cuda_rng_tracker
+ get_cuda_rng_tracker = ds_checkpointing.get_cuda_rng_tracker
+
+ with get_cuda_rng_tracker().fork():
+ init_method(weight)
+
+
+def _initialize_affine_weight_cpu(weight, output_size, input_size,
+ per_partition_size, partition_dim,
+ init_method, stride=1,
+ return_master_weight=False):
+ """Initialize affine weight for model parallel.
+
+ Build the master weight on all processes and scatter
+ the relevant chunk."""
+
+ weight.model_parallel = True
+ weight.partition_dim = partition_dim
+ weight.partition_stride = stride
+
+ # Initialize master weight
+ master_weight = torch.empty(output_size, input_size,
+ dtype=torch.float,
+ requires_grad=False)
+ init_method(master_weight)
+ args = get_args()
+ master_weight = master_weight.to(dtype=args.params_dtype)
+
+ # Split and copy
+ per_partition_per_stride_size = divide(per_partition_size, stride)
+ weight_list = torch.split(master_weight, per_partition_per_stride_size,
+ dim=partition_dim)
+ rank = get_model_parallel_rank()
+ world_size = get_model_parallel_world_size()
+ my_weight_list = weight_list[rank::world_size]
+
+ with torch.no_grad():
+ torch.cat(my_weight_list, dim=partition_dim, out=weight)
+ if return_master_weight:
+ return master_weight
+ return None
+
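+# Partitioning sketch for the CPU path above (illustrative numbers only): with
+# world_size=2, per_partition_size=4 and stride=2, the master weight is split
+# into chunks of size per_partition_size // stride = 2 along partition_dim, and
+# rank r keeps chunks r, r + world_size, ... via weight_list[rank::world_size],
+# so rank 0 gets chunks 0 and 2 while rank 1 gets chunks 1 and 3.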
+
+class VocabParallelEmbedding(torch.nn.Module):
+ """Embedding parallelized in the vocabulary dimension.
+
+ This is mainly adapted from torch.nn.Embedding and all the default
+ values are kept.
+ Arguments:
+ num_embeddings: vocabulary size.
+ embedding_dim: size of hidden state.
+ init_method: method to initialize weights.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim,
+ init_method=init.xavier_normal_):
+ super(VocabParallelEmbedding, self).__init__()
+ # Keep the input dimensions.
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ # Set the defaults for compatibility.
+ self.padding_idx = None
+ self.max_norm = None
+ self.norm_type = 2.
+ self.scale_grad_by_freq = False
+ self.sparse = False
+ self._weight = None
+ self.model_parallel_size = get_model_parallel_world_size()
+ # Divide the weight matrix along the vocabulary dimension.
+ self.vocab_start_index, self.vocab_end_index = \
+ VocabUtility.vocab_range_from_global_vocab_size(
+ self.num_embeddings, get_model_parallel_rank(),
+ self.model_parallel_size)
+ self.num_embeddings_per_partition = self.vocab_end_index - \
+ self.vocab_start_index
+
+ # Allocate weights and initialize.
+ args = get_args()
+ if args.use_cpu_initialization:
+ self.weight = Parameter(torch.empty(
+ self.num_embeddings_per_partition, self.embedding_dim,
+ dtype=args.params_dtype))
+ _initialize_affine_weight_cpu(
+ self.weight, self.num_embeddings, self.embedding_dim,
+ self.num_embeddings_per_partition, 0, init_method)
+ else:
+ self.weight = Parameter(torch.empty(
+ self.num_embeddings_per_partition, self.embedding_dim,
+ device=torch.cuda.current_device(), dtype=args.params_dtype))
+ _initialize_affine_weight_gpu(self.weight, init_method,
+ partition_dim=0, stride=1)
+
+ def forward(self, input_):
+ if self.model_parallel_size > 1:
+ # Build the mask.
+ input_mask = (input_ < self.vocab_start_index) | \
+ (input_ >= self.vocab_end_index)
+ # Mask the input.
+ masked_input = input_.clone() - self.vocab_start_index
+ masked_input[input_mask] = 0
+ else:
+ masked_input = input_
+ # Get the embeddings.
+ output_parallel = F.embedding(masked_input, self.weight,
+ self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq,
+ self.sparse)
+ # Mask the output embedding.
+ if self.model_parallel_size > 1:
+ output_parallel[input_mask, :] = 0.0
+ # Reduce across all the model parallel GPUs.
+ output = reduce_from_model_parallel_region(output_parallel)
+ return output
+
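+# Usage sketch (not executed; illustrative sizes; assumes initialize_model_parallel()
+# has been called and megatron's get_args() is populated, as __init__ requires).
+# Each rank holds only its own vocab shard, and the all-reduce in forward()
+# restores the full embedding on every rank:
+#
+#     embedding = VocabParallelEmbedding(50304, 1024, init_method=init.normal_)
+#     hidden = embedding(token_ids)  # token_ids: LongTensor of token indices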
+
+class ColumnParallelLinear(torch.nn.Module):
+ """Linear layer with column parallelism.
+
+ The linear layer is defined as Y = XA + b. A is parallelized along
+ its second dimension as A = [A_1, ..., A_p].
+
+ Arguments:
+ input_size: first dimension of matrix A.
+ output_size: second dimension of matrix A.
+ bias: If true, add bias
+ gather_output: If true, call all-gather on the output and make Y available
+ to all GPUs; otherwise, every GPU keeps its own output,
+ which is Y_i = XA_i.
+ init_method: method to initialize weights. Note that bias is always set
+ to zero.
+ stride: For the strided linear layers.
+ keep_master_weight_for_test: This was added for testing and should be
+ set to False. It returns the master weight
+ used for initialization.
+ skip_bias_add: This was added to enable performance optimizations where the
+ bias can be fused with other elementwise operations. We skip
+ adding the bias here and return it instead.
+ """
+
+ def __init__(self, input_size, output_size, bias=True, gather_output=True,
+ init_method=init.xavier_normal_, stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=False):
+ super(ColumnParallelLinear, self).__init__()
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.gather_output = gather_output
+ # Divide the weight matrix along the last dimension.
+ world_size = get_model_parallel_world_size()
+ self.output_size_per_partition = divide(output_size, world_size)
+ self.skip_bias_add = skip_bias_add
+ if not bias:
+ self.skip_bias_add = True
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ args = get_args()
+ if args.use_cpu_initialization:
+ self.weight = Parameter(torch.empty(self.output_size_per_partition,
+ self.input_size,
+ dtype=args.params_dtype))
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight, self.output_size, self.input_size,
+ self.output_size_per_partition, 0, init_method,
+ stride=stride, return_master_weight=keep_master_weight_for_test)
+ else:
+ self.weight = Parameter(torch.empty(
+ self.output_size_per_partition, self.input_size,
+ device=torch.cuda.current_device(), dtype=args.params_dtype))
+ _initialize_affine_weight_gpu(self.weight, init_method,
+ partition_dim=0, stride=stride)
+
+ if bias:
+ if args.use_cpu_initialization:
+ self.bias = Parameter(torch.empty(
+ self.output_size_per_partition, dtype=args.params_dtype))
+ else:
+ self.bias = Parameter(torch.empty(
+ self.output_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=args.params_dtype))
+ self.bias.model_parallel = True
+ self.bias.partition_dim = 0
+ self.bias.partition_stride = stride
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input_):
+ # Set up backprop all-reduce.
+ input_parallel = copy_to_model_parallel_region(input_)
+ # Matrix multiply.
+ bias = self.bias if not self.skip_bias_add else None
+ output_parallel = F.linear(input_parallel, self.weight, bias)
+ if self.gather_output:
+ # All-gather across the partitions.
+ output = gather_from_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
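+# Usage sketch (not executed; illustrative sizes; assumes get_args() is
+# populated). With gather_output=False each rank keeps its own slice of the
+# output, which is the usual setup when the next layer is a RowParallelLinear:
+#
+#     dense_h_to_4h = ColumnParallelLinear(1024, 4096, gather_output=False)
+#     partial, _ = dense_h_to_4h(hidden)   # last dim is 4096 // world_size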
+
+class RowParallelLinear(torch.nn.Module):
+ """Linear layer with row parallelism.
+
+ The linear layer is defined as Y = XA + b. A is parallelized along
+ its first dimension and X along its second dimension as:
+ - -
+ | A_1 |
+ | . |
+ A = | . | X = [X_1, ..., X_p]
+ | . |
+ | A_p |
+ - -
+ Arguments:
+ input_size: first dimension of matrix A.
+ output_size: second dimension of matrix A.
+ bias: If true, add bias. Note that bias is not parallelized.
+ input_is_parallel: If true, we assume that the input is already
+ split across the GPUs and we do not split
+ again.
+ init_method: method to initialize weights. Note that bias is always set
+ to zero.
+ stride: For the strided linear layers.
+ keep_master_weight_for_test: This was added for testing and should be
+ set to False. It returns the master weights
+ used for initialization.
+ skip_bias_add: This was added to enable performance optimizations where the
+ bias can be fused with other elementwise operations. We skip
+ adding the bias here and return it instead.
+ """
+
+ def __init__(self, input_size, output_size, bias=True,
+ input_is_parallel=False,
+ init_method=init.xavier_normal_, stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=False):
+ super(RowParallelLinear, self).__init__()
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.input_is_parallel = input_is_parallel
+ # Divide the weight matrix along the last dimension.
+ world_size = get_model_parallel_world_size()
+ self.input_size_per_partition = divide(input_size, world_size)
+ self.skip_bias_add = skip_bias_add
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ args = get_args()
+ if args.use_cpu_initialization:
+ self.weight = Parameter(torch.empty(self.output_size,
+ self.input_size_per_partition,
+ dtype=args.params_dtype))
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight, self.output_size, self.input_size,
+ self.input_size_per_partition, 1, init_method,
+ stride=stride, return_master_weight=keep_master_weight_for_test)
+ else:
+ self.weight = Parameter(torch.empty(
+ self.output_size, self.input_size_per_partition,
+ device=torch.cuda.current_device(), dtype=args.params_dtype))
+ _initialize_affine_weight_gpu(self.weight, init_method,
+ partition_dim=1, stride=stride)
+ if bias:
+ if args.use_cpu_initialization:
+ self.bias = Parameter(torch.empty(self.output_size,
+ dtype=args.params_dtype))
+ else:
+ self.bias = Parameter(torch.empty(
+ self.output_size, device=torch.cuda.current_device(),
+ dtype=args.params_dtype))
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input_):
+ # Set up backprop all-reduce.
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ input_parallel = scatter_to_model_parallel_region(input_)
+ # Matrix multiply.
+ output_parallel = F.linear(input_parallel, self.weight)
+ # All-reduce across all the partitions.
+ output_ = reduce_from_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ output = output_ + self.bias if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+ return output, output_bias
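+# Usage sketch (not executed; illustrative sizes; assumes get_args() is
+# populated). A RowParallelLinear typically consumes the already-split output
+# of a ColumnParallelLinear, so input_is_parallel=True avoids a second scatter,
+# and the all-reduce in forward() produces the full-sized result on every rank:
+#
+#     dense_4h_to_h = RowParallelLinear(4096, 1024, input_is_parallel=True)
+#     hidden, _ = dense_4h_to_h(partial)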
diff --git a/megatron_lm/megatron/mpu/mappings.py b/megatron_lm/megatron/mpu/mappings.py
new file mode 100644
index 0000000..291c499
--- /dev/null
+++ b/megatron_lm/megatron/mpu/mappings.py
@@ -0,0 +1,157 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
+from .utils import split_tensor_along_last_dim
+
+
+def _reduce(input_):
+ """All-reduce the the input tensor across model parallel group."""
+
+ # Bypass the function if we are using only 1 GPU.
+ if get_model_parallel_world_size() == 1:
+ return input_
+
+ # All-reduce.
+ torch.distributed.all_reduce(input_, group=get_model_parallel_group())
+
+ return input_
+
+
+def _split(input_):
+ """Split the tensor along its last dimension and keep the
+ corresponding slice."""
+
+ world_size = get_model_parallel_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ # Split along last dimension.
+ input_list = split_tensor_along_last_dim(input_, world_size)
+
+ # Note: torch.split does not create contiguous tensors by default.
+ rank = get_model_parallel_rank()
+ output = input_list[rank].contiguous()
+
+ return output
+
+
+def _gather(input_):
+ """Gather tensors and concatinate along the last dimension."""
+
+ world_size = get_model_parallel_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ # Size and dimension.
+ last_dim = input_.dim() - 1
+ rank = get_model_parallel_rank()
+
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ tensor_list[rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=last_dim).contiguous()
+
+ return output
+
+
+class _CopyToModelParallelRegion(torch.autograd.Function):
+ """Pass the input to the model parallel region."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return input_
+
+ @staticmethod
+ def forward(ctx, input_):
+ return input_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _reduce(grad_output)
+
+
+class _ReduceFromModelParallelRegion(torch.autograd.Function):
+ """All-redcue the input from the model parallel region."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output
+
+
+class _ScatterToModelParallelRegion(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _split(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _split(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _gather(grad_output)
+
+
+class _GatherFromModelParallelRegion(torch.autograd.Function):
+ """Gather the input from model parallel region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _gather(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _gather(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output)
+
+
+# -----------------
+# Helper functions.
+# -----------------
+
+def copy_to_model_parallel_region(input_):
+ return _CopyToModelParallelRegion.apply(input_)
+
+
+def reduce_from_model_parallel_region(input_):
+ return _ReduceFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_model_parallel_region(input_):
+ return _ScatterToModelParallelRegion.apply(input_)
+
+
+def gather_from_model_parallel_region(input_):
+ return _GatherFromModelParallelRegion.apply(input_)
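+# The four helpers above come in conjugate forward/backward pairs:
+# copy_to_model_parallel_region is the identity in forward and an all-reduce in
+# backward, while reduce_from_model_parallel_region is the reverse; likewise
+# scatter/gather split and concatenate along the last dimension. A typical
+# column-parallel matmul therefore looks like this (a sketch, not executed;
+# weight_shard stands for this rank's slice of the full weight):
+#
+#     x_parallel = copy_to_model_parallel_region(x)
+#     y_parallel = torch.nn.functional.linear(x_parallel, weight_shard)
+#     y = gather_from_model_parallel_region(y_parallel)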
diff --git a/megatron_lm/megatron/mpu/random.py b/megatron_lm/megatron/mpu/random.py
new file mode 100644
index 0000000..261f0a4
--- /dev/null
+++ b/megatron_lm/megatron/mpu/random.py
@@ -0,0 +1,319 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parts of the code here are adapted from PyTorch
+# repo: https://github.com/pytorch/pytorch
+
+import contextlib
+
+import torch
+from torch import _C
+from torch.cuda import _lazy_call, device as device_ctx_manager
+from torch.utils.checkpoint import detach_variable
+
+from megatron import get_args
+from megatron.memory import allocate_mem_buff
+
+from .initialize import get_data_parallel_rank
+from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_world_size
+
+
+# Default name for the model parallel rng tracker.
+_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
+
+
+# Whether to apply model parallelism to checkpointed hidden states.
+_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
+
+
+def init_checkpointed_activations_memory_buffer():
+ """Initializ the memory buffer for the checkpointed activations."""
+ args = get_args()
+
+ per_layer = args.batch_size * args.max_position_embeddings * \
+ args.hidden_size // args.model_parallel_size
+ assert args.num_layers % args.checkpoint_num_layers == 0, \
+ 'number of layers is not divisible by checkpoint-num-layers'
+ num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+ numel = per_layer * num_checkpointer_layers
+ dtype = torch.half
+ if not args.fp16:
+ dtype = torch.float
+
+ global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
+ assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
+ 'checkpointed activations memory buffer is already allocated.'
+ _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
+ 'checkpointed activations', numel, dtype, track_usage=False)
+
+
+def reset_checkpointed_activations_memory_buffer():
+ """Reset the memory used for checkpointing."""
+ if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+ _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
+
+
+def _set_cuda_rng_state(new_state, device=-1):
+ """Sets the random number generator state of the current GPU.
+
+ Arguments:
+ new_state (torch.ByteTensor): The desired state
+ This function is adapted from the PyTorch repo (torch.cuda.set_rng_state)
+ with a single change: the input state is not cloned. Cloning caused
+ major performance issues for 4+ GPU cases.
+ """
+ if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
+ # older PyTorch
+ def cb():
+ with device_ctx_manager(device):
+ _C._cuda_setRNGState(new_state)
+ else:
+ # newer PyTorch
+ if device == -1:
+ device = torch.device('cuda')
+ elif isinstance(device, str):
+ device = torch.device(device)
+ elif isinstance(device, int):
+ device = torch.device('cuda', device)
+
+ def cb():
+ idx = device.index
+ if idx is None:
+ idx = torch.cuda.current_device()
+ default_generator = torch.cuda.default_generators[idx]
+ default_generator.set_state(new_state)
+
+ _lazy_call(cb)
+
+
+def split_tensor_into_1d_equal_chunks(tensor):
+ """Break a tensor into equal 1D chunks."""
+ data = tensor.view(-1)
+ partition_size = torch.numel(data) // get_model_parallel_world_size()
+ start_index = partition_size * get_model_parallel_rank()
+ end_index = start_index + partition_size
+ return data[start_index:end_index]
+
+
+def gather_split_1d_tensor(tensor):
+ """Opposite of above function, gather values from model parallel ranks."""
+ world_size = get_model_parallel_world_size()
+ numel = torch.numel(tensor)
+ numel_gathered = world_size * numel
+ gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False)
+ chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+ torch.distributed.all_gather(chunks, tensor,
+ group=get_model_parallel_group())
+ return gathered
+
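+# Shape note (illustrative): for a chunk of N elements per rank and a model
+# parallel world size of P, gather_split_1d_tensor returns a flat tensor of
+# N * P elements; the caller is expected to .view() it back to the original
+# shape, as CheckpointFunction.backward does below.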
+
+class CudaRNGStatesTracker:
+ """Tracker for the cuda RNG states.
+
+ Using the `add` method, a cuda rng state is initialized based on
+ the input `seed` and is assigned to `name`. Later, by forking the
+ rng state, we can perform operations and return to our starting
+ cuda state.
+ """
+
+ def __init__(self):
+ # Map from a string name to the cuda rng state.
+ self.states_ = {}
+ # Seeds are just for bookkeeping and ensure no seed is set twice.
+ self.seeds_ = set()
+
+ def reset(self):
+ """Set to the initial state (no tracker)."""
+ self.states_ = {}
+ self.seeds_ = set()
+
+ def get_states(self):
+ """Get rng states. Copy the dictionary so we have direct
+ pointers to the states, not just a pointer to the dictionary."""
+ states = {}
+ for name in self.states_:
+ states[name] = self.states_[name]
+ return states
+
+ def set_states(self, states):
+ """Set the rng states. For efficiency purposes, we do not check
+ the size of seed for compatibility."""
+ self.states_ = states
+
+ def add(self, name, seed):
+ """Track the rng state."""
+ # Check seed is not already used.
+ if seed in self.seeds_:
+ raise Exception('seed {} already exists'.format(seed))
+ self.seeds_.add(seed)
+ # Check that state is not already defined.
+ if name in self.states_:
+ raise Exception('cuda rng state {} already exists'.format(name))
+ # Get the current rng state.
+ orig_rng_state = torch.cuda.get_rng_state()
+ # Set the new state and store it.
+ torch.cuda.manual_seed(seed)
+ self.states_[name] = torch.cuda.get_rng_state()
+ # Reset rng state to what it was.
+ _set_cuda_rng_state(orig_rng_state)
+
+ @contextlib.contextmanager
+ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
+ """Fork the cuda rng state, perform operations, and exit with
+ the original state."""
+ # Check if we have added the state
+ if name not in self.states_:
+ raise Exception('cuda rng state {} is not added'.format(name))
+ # Store current rng state.
+ orig_cuda_rng_state = torch.cuda.get_rng_state()
+ # Set rng state to the desired one
+ _set_cuda_rng_state(self.states_[name])
+ # Do the stuff we wanted to do.
+ try:
+ yield
+ finally:
+ # Update the current rng state for later use.
+ self.states_[name] = torch.cuda.get_rng_state()
+ # And set the state to the original state we started with.
+ _set_cuda_rng_state(orig_cuda_rng_state)
+
+
+# RNG tracker object.
+_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+
+
+def get_cuda_rng_tracker():
+ """Get cuda rng tracker."""
+ return _CUDA_RNG_STATE_TRACKER
+
+
+def model_parallel_cuda_manual_seed(seed):
+ """Initialize model parallel cuda seed.
+
+ This function should be called after the model parallel is
+ initialized. Also, no torch.cuda.manual_seed should be called
+ after this function. Basically, this is a replacement for that
+ function.
+ Two sets of RNG states are tracked:
+ default state: This is for data parallelism and is the same among a
+ set of model parallel GPUs but different across
+ different model parallel groups. This is used, for
+ example, for dropout in the non-model-parallel regions.
+ model-parallel state: This state is different among a set of model
+ parallel GPUs, but the same across data parallel
+ groups. This is used for example for dropout in
+ model parallel regions.
+ """
+ # 2718 is just for fun and any POSITIVE value will work.
+ offset = seed + 2718
+ model_parallel_seed = offset + get_model_parallel_rank()
+ # Data parallel gets the original seed.
+ data_parallel_seed = seed
+
+ if torch.distributed.get_rank() == 0:
+ print('> initializing model parallel cuda seeds on global rank {}, '
+ 'model parallel rank {}, and data parallel rank {} with '
+ 'model parallel seed: {} and data parallel seed: {}'.format(
+ torch.distributed.get_rank(), get_model_parallel_rank(),
+ get_data_parallel_rank(), model_parallel_seed,
+ data_parallel_seed), flush=True)
+ _CUDA_RNG_STATE_TRACKER.reset()
+ # Set the default state.
+ torch.cuda.manual_seed(data_parallel_seed)
+ # and model parallel state.
+ _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
+ model_parallel_seed)
+
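+# Usage sketch (not executed; assumes model_parallel_cuda_manual_seed() has
+# been called). Work that must differ across model parallel ranks, such as the
+# weight initialization in layers.py, runs under the tracked state:
+#
+#     with get_cuda_rng_tracker().fork():
+#         torch.nn.init.normal_(weight)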
+
+class CheckpointFunction(torch.autograd.Function):
+ """This function is adapted from torch.utils.checkpoint with
+ two main changes:
+ 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
+ 2) the states in the model parallel tracker are also properly
+ tracked/set/reset.
+ """
+ @staticmethod
+ def forward(ctx, run_function, *args):
+ ctx.run_function = run_function
+
+ # Copy the rng states.
+ ctx.fwd_cpu_rng_state = torch.get_rng_state()
+ ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+ ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ with torch.no_grad():
+ outputs = run_function(*args)
+
+ # Divide hidden states across model parallel group and only keep
+ # the chunk corresponding to the current rank.
+ if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+ ctx.input_0_shape = args[0].data.shape
+ args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
+ args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
+ args[0].data)
+
+ # Store everything.
+ ctx.save_for_backward(*args)
+
+
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError("Checkpointing is not compatible with .grad(), "
+ "please use .backward() if possible")
+ inputs = ctx.saved_tensors
+ if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+ inputs[0].data = gather_split_1d_tensor(inputs[0].data)
+ inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
+
+ # Store the current states.
+ bwd_cpu_rng_state = torch.get_rng_state()
+ bwd_cuda_rng_state = torch.cuda.get_rng_state()
+ bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ # Set the states to what they were before the forward pass.
+ torch.set_rng_state(ctx.fwd_cpu_rng_state)
+ _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
+
+ # Compute the forward pass.
+ detached_inputs = detach_variable(inputs)
+ with torch.enable_grad():
+ outputs = ctx.run_function(*detached_inputs)
+
+ # Set the states back to what they were at the start of this function.
+ torch.set_rng_state(bwd_cpu_rng_state)
+ _set_cuda_rng_state(bwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+ torch.autograd.backward(outputs, args)
+ grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
+ for inp in detached_inputs)
+ return (None,) + grads
+
+
+def checkpoint(function, *args):
+ """Checkpoint a model or part of the model.
+ This has been directly copied from torch.utils.checkpoint."""
+ return CheckpointFunction.apply(function, *args)
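+# Usage sketch (not executed; transformer_layer is just a placeholder): the
+# wrapped function receives the positional arguments passed to checkpoint() and
+# is re-run under the saved RNG states during backward:
+#
+#     def custom_forward(hidden_states, attention_mask):
+#         return transformer_layer(hidden_states, attention_mask)
+#
+#     hidden_states = checkpoint(custom_forward, hidden_states, attention_mask)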
diff --git a/megatron_lm/megatron/mpu/tests/__init__.py b/megatron_lm/megatron/mpu/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/__init__.py
diff --git a/megatron_lm/megatron/mpu/tests/commons.py b/megatron_lm/megatron/mpu/tests/commons.py
new file mode 100644
index 0000000..5e7a186
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/commons.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import random
+import numpy
+import torch
+
+import mpu
+
+
+class IdentityLayer(torch.nn.Module):
+ def __init__(self, size, scale=1.0):
+ super(IdentityLayer, self).__init__()
+ self.weight = torch.nn.Parameter(scale * torch.randn(size))
+
+ def forward(self):
+ return self.weight
+
+
+def set_random_seed(seed):
+ """Set random seed for reproducability."""
+ random.seed(seed)
+ numpy.random.seed(seed)
+ torch.manual_seed(seed)
+ mpu.model_parallel_cuda_manual_seed(seed)
+
+
+def initialize_distributed(backend='nccl'):
+ """Initialize torch.distributed."""
+ # Get local rank in case it is provided.
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--local_rank', type=int, default=None,
+ help='local rank passed from distributed launcher')
+ args = parser.parse_args()
+ local_rank = args.local_rank
+
+ # Get rank and world size.
+ rank = int(os.getenv('RANK', '0'))
+ world_size = int(os.getenv("WORLD_SIZE", '1'))
+
+ print('> initializing torch.distributed with local rank: {}, '
+ 'rank: {}, world size: {}'.format(local_rank, rank, world_size))
+
+ # Set the device id.
+ device = rank % torch.cuda.device_count()
+ if local_rank is not None:
+ device = local_rank
+ torch.cuda.set_device(device)
+
+ # Call the init process.
+ init_method = 'tcp://'
+ master_ip = os.getenv('MASTER_ADDR', 'localhost')
+ master_port = os.getenv('MASTER_PORT', '6000')
+ init_method += master_ip + ':' + master_port
+ torch.distributed.init_process_group(
+ backend=backend,
+ world_size=world_size,
+ rank=rank,
+ init_method=init_method)
+
+
+def print_separator(message):
+ torch.distributed.barrier()
+ filler_len = (78 - len(message)) // 2
+ filler = '-' * filler_len
+ string = '\n' + filler + ' {} '.format(message) + filler
+ if torch.distributed.get_rank() == 0:
+ print(string, flush=True)
+ torch.distributed.barrier()
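+# How these tests are typically launched (a sketch; the exact command depends
+# on your environment): one process per GPU, for example with the PyTorch
+# launcher, which supplies the --local_rank argument parsed in
+# initialize_distributed():
+#
+#     python -m torch.distributed.launch --nproc_per_node=2 test_initialize.py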
diff --git a/megatron_lm/megatron/mpu/tests/test_cross_entropy.py b/megatron_lm/megatron/mpu/tests/test_cross_entropy.py
new file mode 100644
index 0000000..41c22fc
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/test_cross_entropy.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import sys
+sys.path.append("../..")
+
+import torch
+import torch.nn.functional as F
+
+from commons import set_random_seed
+from commons import IdentityLayer
+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+from mpu.cross_entropy import vocab_parallel_cross_entropy
+
+
+def torch_cross_entropy(batch_size, seq_length, vocab_size,
+ logits_scale, seed):
+ set_random_seed(seed)
+ identity = IdentityLayer((batch_size, seq_length, vocab_size),
+ scale=logits_scale).cuda()
+ logits = identity()
+ target = torch.cuda.LongTensor(
+ size=(batch_size, seq_length)).random_(0, vocab_size)
+ loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
+ target.view(-1),
+ reduction='none').view_as(target).mean()
+ loss.backward()
+ return loss, identity.weight.grad
+
+
+def mpu_cross_entropy(batch_size, seq_length, vocab_size,
+ logits_scale, seed):
+ set_random_seed(seed)
+ identity = IdentityLayer((batch_size, seq_length, vocab_size),
+ scale=logits_scale).cuda()
+ logits = identity()
+ logits_parallel = mpu.scatter_to_model_parallel_region(logits)
+ target = torch.cuda.LongTensor(
+ size=(batch_size, seq_length)).random_(0, vocab_size)
+ loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
+ loss.backward()
+ return loss, identity.weight.grad
+
+
+def test_cross_entropy(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing cross entropy with model parallel size {} ...'.
+ format(model_parallel_size))
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ batch_size = 13
+ seq_length = 17
+ vocab_size_per_partition = 11
+ logits_scale = 1000.0
+ vocab_size = vocab_size_per_partition * model_parallel_size
+ seed = 1234
+
+ loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
+ vocab_size, logits_scale,
+ seed)
+ loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
+ vocab_size, logits_scale,
+ seed)
+
+ error = loss_torch.sub_(loss_mpu).abs().max()
+ print(' max error in loss on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ error = grad_torch.sub_(grad_mpu).abs().max()
+ print(' max error in grad on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+if __name__ == '__main__':
+
+ initialize_distributed()
+ world_size = torch.distributed.get_world_size()
+
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test cross entropy')
+ test_cross_entropy(model_parallel_size)
+ model_parallel_size *= 2
diff --git a/megatron_lm/megatron/mpu/tests/test_data.py b/megatron_lm/megatron/mpu/tests/test_data.py
new file mode 100644
index 0000000..612d841
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/test_data.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import operator
+import sys
+sys.path.append("../..")
+
+import torch
+
+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+from mpu import data as data_utils
+
+
+def test_broadcast_data(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing broadcast_data with model parallel size {} ...'.
+ format(model_parallel_size))
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ torch.manual_seed(1234 + mpu.get_data_parallel_rank())
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ key_size_t = {'key1': [7, 11],
+ 'key2': [8, 2, 1],
+ 'key3': [13],
+ 'key4': [5, 1, 2],
+ 'key5': [5, 12]}
+ keys = list(key_size_t.keys())
+
+ data = {}
+ data_t = {}
+ for key in key_size_t:
+ data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
+ data_t[key] = data[key].clone()
+ data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
+ data_t['keyX'] = data['keyX'].clone()
+ if mpu.get_model_parallel_rank() != 0:
+ data = None
+
+ data_utils._check_data_types(keys, data_t, torch.int64)
+ key_size, key_numel, \
+ total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
+ for key in keys:
+ assert key_size[key] == key_size_t[key]
+ total_numel_t = 0
+ for key in keys:
+ target_size = functools.reduce(operator.mul, key_size_t[key], 1)
+ assert key_numel[key] == target_size
+ total_numel_t += target_size
+ assert total_numel == total_numel_t
+
+ data_b = data_utils.broadcast_data(keys, data, torch.int64)
+ for key in keys:
+ tensor = data_t[key].cuda()
+ assert data_b[key].sub(tensor).abs().max() == 0
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+if __name__ == '__main__':
+
+ initialize_distributed()
+ world_size = torch.distributed.get_world_size()
+
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test broadcast data')
+ test_broadcast_data(model_parallel_size)
+ model_parallel_size *= 2
diff --git a/megatron_lm/megatron/mpu/tests/test_initialize.py b/megatron_lm/megatron/mpu/tests/test_initialize.py
new file mode 100644
index 0000000..2a023a3
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/test_initialize.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append("../..")
+
+import torch
+
+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+
+
+def test_initialize_model_parallel(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing initialize_model_parallel with size {} ...'.format(
+ model_parallel_size))
+ model_parallel_size_ = min(model_parallel_size,
+ torch.distributed.get_world_size())
+ assert not mpu.model_parallel_is_initialized()
+ mpu.initialize_model_parallel(model_parallel_size_)
+ assert mpu.model_parallel_is_initialized()
+
+ # Checks.
+ def check(group, world_size, rank):
+ assert world_size == torch.distributed.get_world_size(group=group)
+ assert rank == torch.distributed.get_rank(group=group)
+
+ # Model parallel.
+ world_size = model_parallel_size_
+ rank = torch.distributed.get_rank() % model_parallel_size_
+ assert world_size == mpu.get_model_parallel_world_size()
+ assert rank == mpu.get_model_parallel_rank()
+ check(mpu.get_model_parallel_group(), world_size, rank)
+
+ # Data parallel.
+ world_size = torch.distributed.get_world_size() // model_parallel_size_
+ rank = torch.distributed.get_rank() // model_parallel_size_
+ assert world_size == mpu.get_data_parallel_world_size()
+ assert rank == mpu.get_data_parallel_rank()
+ check(mpu.get_data_parallel_group(), world_size, rank)
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+def test_get_model_parallel_src_rank(model_parallel_size_):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing get_model_parallel_src_rank with size {} ...'.format(
+ model_parallel_size_))
+ model_parallel_size = min(model_parallel_size_,
+ torch.distributed.get_world_size())
+ assert not mpu.model_parallel_is_initialized()
+ mpu.initialize_model_parallel(model_parallel_size)
+ assert mpu.model_parallel_is_initialized()
+
+ # Checks
+ src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
+ assert mpu.get_model_parallel_src_rank() == src_rank
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+if __name__ == '__main__':
+
+ initialize_distributed()
+ world_size = torch.distributed.get_world_size()
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test initialize model parallel')
+ test_initialize_model_parallel(model_parallel_size)
+ print_separator('test model parallel source rank')
+ test_get_model_parallel_src_rank(model_parallel_size)
+ model_parallel_size *= 2
diff --git a/megatron_lm/megatron/mpu/tests/test_layers.py b/megatron_lm/megatron/mpu/tests/test_layers.py
new file mode 100644
index 0000000..a7f2d9c
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/test_layers.py
@@ -0,0 +1,530 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import sys
+sys.path.append("../..")
+
+import torch
+import torch.nn.init as init
+from torch.nn.parameter import Parameter
+
+from commons import set_random_seed
+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+from mpu import layers
+
+
+def test_parallel_embedding(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing parallel embedding with model parallel size {} ...'.
+ format(model_parallel_size))
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ batch_size = 17
+ seq_length = 23
+ vocab_size = 48
+ hidden_size = 16
+ seed = 1236
+
+ set_random_seed(123)
+ input_data = torch.LongTensor(
+ size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
+ loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
+
+ set_random_seed(seed)
+ embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
+
+ output = embedding_original(input_data)
+ loss_original = torch.mul(output, loss_weight).sum()
+ loss_original.backward()
+
+ set_random_seed(seed)
+ embedding_parallel = layers.ParallelEmbedding(
+ vocab_size, hidden_size, init_method=init.normal_).cuda()
+ output = embedding_parallel(input_data)
+ loss_parallel = torch.mul(output, loss_weight).sum()
+ loss_parallel.backward()
+
+ set_random_seed(seed)
+ embedding_vocab_parallel = layers.VocabParallelEmbedding(
+ vocab_size, hidden_size, init_method=init.normal_).cuda()
+ output = embedding_vocab_parallel(input_data)
+ loss_vocab_parallel = torch.mul(output, loss_weight).sum()
+ loss_vocab_parallel.backward()
+
+ torch.distributed.barrier()
+ error = loss_parallel.sub(loss_original).abs()
+ print(' error in loss (parallel) on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-12, 'error: {}'.format(error)
+
+ torch.distributed.barrier()
+ error = loss_vocab_parallel.sub(loss_original).abs()
+ print(' error in loss (vocab parallel) on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-12, 'error: {}'.format(error)
+
+ weight_grad_orig = torch.split(embedding_original.weight.grad,
+ hidden_size // model_parallel_size,
+ 1)[mpu.get_model_parallel_rank()]
+ error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
+ print(' error in grad (parallel) on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-12, 'error: {}'.format(error)
+
+ weight_grad_orig = torch.split(embedding_original.weight.grad,
+ vocab_size // model_parallel_size,
+ 0)[mpu.get_model_parallel_rank()]
+ error = embedding_vocab_parallel.weight.grad.sub(
+ weight_grad_orig).abs().max()
+ print(' error in grad (vocab parallel) on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-12, 'error: {}'.format(error)
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+def test_initialize_affine_weight(model_parallel_size):
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ if torch.distributed.get_rank() == 0:
+ print('> testing initialize_affine_weight with model parallel '
+ 'size: {}'.format(model_parallel_size))
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ seed = 12345
+ input_size_coeff = 13
+ input_size = input_size_coeff * model_parallel_size
+ output_size_coeff = 17
+ output_size = output_size_coeff * model_parallel_size
+
+ # ---------------
+ # Column parallel
+ # ---------------
+ weight = torch.empty(output_size_coeff, input_size)
+ set_random_seed(seed)
+ layers._initialize_affine_weight(weight, output_size, input_size,
+ output_size_coeff, 0,
+ torch.nn.init.normal_)
+ # Target.
+ set_random_seed(seed)
+ master_weight = torch.empty(output_size, input_size)
+ torch.nn.init.normal_(master_weight)
+ rank = mpu.get_model_parallel_rank()
+ my_weight = torch.split(master_weight, output_size_coeff,
+ dim=0)[rank].contiguous().clone()
+
+ # Compare.
+ error = weight.sub(my_weight).abs().max()
+ torch.distributed.barrier()
+ print(' column parallel max error (should be zero) on global rank '
+ '{}: {}'.format(torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # ------------
+ # Row parallel
+ # ------------
+ weight = torch.empty(output_size, input_size_coeff)
+ set_random_seed(seed)
+ mpu.layers._initialize_affine_weight(weight, output_size, input_size,
+ input_size_coeff, 1,
+ torch.nn.init.normal_)
+ # Target.
+ set_random_seed(seed)
+ master_weight = torch.empty(output_size, input_size)
+ torch.nn.init.normal_(master_weight)
+ rank = mpu.get_model_parallel_rank()
+ my_weight = torch.split(master_weight, input_size_coeff,
+ dim=1)[rank].contiguous().clone()
+
+ # Compare.
+ error = weight.sub(my_weight).abs().max()
+ torch.distributed.barrier()
+ print(' row parallel max error (should be zero) on global rank '
+ '{}: {}'.format(torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print(' >> passed the test :-)')
+
+
+class IdentityLayer2D(torch.nn.Module):
+ def __init__(self, m, n):
+ super(IdentityLayer2D, self).__init__()
+ self.weight = Parameter(torch.Tensor(m, n))
+ torch.nn.init.xavier_normal_(self.weight)
+
+ def forward(self):
+ return self.weight
+
+
+def test_column_parallel_linear(model_parallel_size):
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ if torch.distributed.get_rank() == 0:
+ print('> testing ColumnParallelLinear with model parallel '
+ 'size: {}'.format(model_parallel_size))
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ seed = 12345
+ set_random_seed(seed)
+ input_size_coeff = 13
+ input_size = input_size_coeff * model_parallel_size
+ output_size_coeff = 17
+ output_size = output_size_coeff * model_parallel_size
+ batch_size = 7
+
+ # Network
+ identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
+ linear_layer = mpu.ColumnParallelLinear(
+ input_size, output_size, keep_master_weight_for_test=True).cuda()
+ loss_weight = torch.randn([batch_size, output_size]).cuda()
+ # Forward
+ input_ = identity_layer()
+ output = linear_layer(input_)
+ loss = torch.mul(output, loss_weight).sum()
+ # Backward
+ loss.backward()
+
+ # Values.
+ dLdY = loss_weight
+ X = identity_layer.weight
+ A = linear_layer.master_weight.cuda()
+ dLdA = torch.matmul(dLdY.t(), X)
+ dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
+ dLdX = torch.matmul(dLdY, A)
+
+ rank = mpu.get_model_parallel_rank()
+ my_dLdA = torch.split(dLdA, output_size_coeff,
+ dim=0)[rank].contiguous().clone()
+ error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' error in dLdA on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ my_dLdb = torch.split(dLdb, output_size_coeff,
+ dim=0)[rank].contiguous().clone()
+ error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
+ torch.distributed.barrier()
+ print(' error in dLdb on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ error = dLdX.sub(identity_layer.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' error in dLdX on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print(' >> passed the test :-)')
+
+
+def test_row_parallel_linear(model_parallel_size):
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ if torch.distributed.get_rank() == 0:
+ print('> testing RowParallelLinear with model parallel '
+ 'size: {}'.format(model_parallel_size))
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ seed = 12345
+ set_random_seed(seed)
+ input_size_coeff = 13
+ input_size = input_size_coeff * model_parallel_size
+ output_size_coeff = 17
+ output_size = output_size_coeff * model_parallel_size
+ batch_size = 7
+
+ # Network
+ identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
+ linear_layer = mpu.RowParallelLinear(
+ input_size, output_size, keep_master_weight_for_test=True).cuda()
+ loss_weight = torch.randn([batch_size, output_size]).cuda()
+ # Forward
+ input_ = identity_layer()
+ output = linear_layer(input_)
+ loss = torch.mul(output, loss_weight).sum()
+ # Backward
+ loss.backward()
+
+ # Values.
+ dLdY = loss_weight
+ X = identity_layer.weight
+ A = linear_layer.master_weight.cuda()
+ dLdA = torch.matmul(dLdY.t(), X)
+ dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
+ dLdX = torch.matmul(dLdY, A)
+
+ rank = mpu.get_model_parallel_rank()
+ my_dLdA = torch.split(dLdA, input_size_coeff,
+ dim=1)[rank].contiguous().clone()
+ error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' error in dLdA on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ error = dLdb.sub(linear_layer.bias.grad).abs().max()
+ torch.distributed.barrier()
+ print(' error in dLdb on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ error = dLdX.sub(identity_layer.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' error in dLdX on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print(' >> passed the test :-)')
+
+
+class IdentityLayer3D(torch.nn.Module):
+ def __init__(self, m, n, k):
+ super(IdentityLayer3D, self).__init__()
+ self.weight = Parameter(torch.Tensor(m, n, k))
+ torch.nn.init.xavier_normal_(self.weight)
+
+ def forward(self):
+ return self.weight
+
+
+def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
+ hidden_size_per_att_head, dropout_prob, batch_size,
+ sequence_length):
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ seed = 12345
+ set_random_seed(seed)
+
+ num_att_heads = num_att_heads_per_partition * \
+ torch.distributed.get_world_size()
+ hidden_size = hidden_size_per_att_head * num_att_heads
+
+ # Network
+ identity_layer = IdentityLayer3D(batch_size, sequence_length,
+ hidden_size).cuda()
+ attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
+ dropout_prob).cuda()
+ loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
+ attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
+ # Forward
+ input_ = identity_layer()
+ output = attention_layer(input_, attention_mask)
+ loss = torch.mul(output, loss_weight).sum()
+ # Backward
+ loss.backward()
+
+ rank = mpu.get_model_parallel_rank()
+ mpu.destroy_model_parallel()
+ return rank, hidden_size, model_parallel_size, loss, \
+ attention_layer, identity_layer
+
+
+def test_parallel_self_attention(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing ParallelSelfAttention with model parallel '
+ 'size: {}'.format(model_parallel_size))
+
+ num_att_heads_per_partition = 3
+ hidden_size_per_att_head = 7
+ dropout_prob = 0.0 # has to be zero
+ batch_size = 5
+ sequence_length = 13
+
+ rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
+ attention_layer_1, identity_layer_1 = parallel_self_attention(
+ 1, num_att_heads_per_partition,
+ hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
+
+ rank, hidden_size, model_parallel_size, loss, \
+ attention_layer, identity_layer = parallel_self_attention(
+ model_parallel_size, num_att_heads_per_partition,
+ hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
+ assert hidden_size_1 == hidden_size
+
+ error = loss_1.sub(loss).abs().max()
+ torch.distributed.barrier()
+ print(' loss error on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 5.0e-6
+
+ my_lin_grad_list = torch.split(
+ attention_layer_1.query_key_value.weight.grad,
+ hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
+ my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
+ error = my_lin_grad.sub(
+ attention_layer.query_key_value.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' weight gradient error on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 5.0e-6
+
+ error = identity_layer_1.weight.grad.sub(
+ identity_layer.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' input gradient error on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 5.0e-6
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print(' >> passed the test :-)')
+
+
+def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
+ hidden_size_per_att_head, batch_size, sequence_length):
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ seed = 12345
+ set_random_seed(seed)
+
+ num_att_heads = num_att_heads_per_partition * \
+ torch.distributed.get_world_size()
+ hidden_size = hidden_size_per_att_head * num_att_heads
+ intermediate_size = 4 * hidden_size
+
+ # Network
+ identity_layer = IdentityLayer3D(batch_size, sequence_length,
+ hidden_size).cuda()
+ transformer_layer = mpu.BertParallelTransformerLayer(
+ hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
+ torch.nn.functional.relu, 1.0e-5).cuda()
+
+ loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
+ attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
+ # Forward
+ input_ = identity_layer()
+ output = transformer_layer(input_, attention_mask)
+ loss = torch.mul(output, loss_weight).sum()
+ # Backward
+ loss.backward()
+
+ rank = mpu.get_model_parallel_rank()
+ mpu.destroy_model_parallel()
+ return rank, hidden_size, model_parallel_size, loss, \
+ transformer_layer, identity_layer
+
+
+def test_parallel_transformer_layer(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing ParallelTransformerLayer with model parallel '
+ 'size: {}'.format(model_parallel_size))
+
+ num_att_heads_per_partition = 3
+ hidden_size_per_att_head = 7
+ batch_size = 5
+ sequence_length = 13
+
+ rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
+ transformer_layer_1, identity_layer_1 = parallel_transformer(
+ 1, num_att_heads_per_partition,
+ hidden_size_per_att_head, batch_size, sequence_length)
+
+ rank, hidden_size, model_parallel_size, loss, \
+ transformer_layer, identity_layer = parallel_transformer(
+ model_parallel_size, num_att_heads_per_partition,
+ hidden_size_per_att_head, batch_size, sequence_length)
+
+ error = loss_1.sub(loss).abs().max()
+ torch.distributed.barrier()
+ print(' loss error on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 5.0e-5, 'error: {}'.format(error)
+
+ error = identity_layer_1.weight.grad.sub(
+ identity_layer.weight.grad).abs().max()
+ torch.distributed.barrier()
+ print(' input gradient error on global rank {}: {}'.format(
+ torch.distributed.get_rank(), error))
+ assert error < 5.0e-5, 'error: {}'.format(error)
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print(' >> passed the test :-)')
+
+
+if __name__ == '__main__':
+
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ initialize_distributed()
+ world_size = torch.distributed.get_world_size()
+
+ print_separator('test initialize affine weight')
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ test_initialize_affine_weight(model_parallel_size)
+ model_parallel_size *= 2
+
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test parallel embedding')
+ test_parallel_embedding(model_parallel_size)
+ model_parallel_size *= 2
+
+ print_separator('test column-parallel linear')
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ test_column_parallel_linear(model_parallel_size)
+ model_parallel_size *= 2
+
+ print_separator('test row-parallel linear')
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ test_row_parallel_linear(model_parallel_size)
+ model_parallel_size *= 2
+
+ print_separator('test parallel self-attention')
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ test_parallel_self_attention(model_parallel_size)
+ model_parallel_size *= 2
+
+ print_separator('test parallel transformer')
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ test_parallel_transformer_layer(model_parallel_size)
+ model_parallel_size *= 2
diff --git a/megatron_lm/megatron/mpu/tests/test_random.py b/megatron_lm/megatron/mpu/tests/test_random.py
new file mode 100644
index 0000000..3ce7f8e
--- /dev/null
+++ b/megatron_lm/megatron/mpu/tests/test_random.py
@@ -0,0 +1,204 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+import torch
+import sys
+sys.path.append("../..")
+
+
+def test_set_cuda_rng_state(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing set_rng_state with size {} ...'.
+ format(model_parallel_size))
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ size = 123
+ seed = 1234
+ torch.cuda.manual_seed(1234)
+ tensor = torch.cuda.FloatTensor(size)
+
+ # Get the state
+ rng_state = torch.cuda.get_rng_state()
+ rng_state_copy = rng_state.clone()
+
+ # Advance the generator by drawing a few tensors.
+ for _ in range(5):
+ torch.randn(size, out=tensor)
+ result_1 = tensor.clone()
+
+ assert rng_state.sub(rng_state_copy).max() == 0
+ assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
+
+ # State should be different.
+ new_rng_state = torch.cuda.get_rng_state()
+ max_diff = new_rng_state.sub(rng_state).max()
+ print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
+ format(torch.distributed.get_rank(), max_diff))
+ assert max_diff > 0
+
+ # Reset the rng state and do the same stuff.
+ mpu.random._set_cuda_rng_state(rng_state)
+ for _ in range(5):
+ torch.randn(size, out=tensor)
+ mpu.random._set_cuda_rng_state(rng_state)
+ for _ in range(5):
+ torch.randn(size, out=tensor)
+ result_2 = tensor.clone()
+
+ # Results should be the same
+ error = result_2.sub(result_1).abs().max()
+ print(' max error in generated tensors (should be zero) on '
+ 'global rank {}: {}'.format(torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # Input state should have remained intact.
+ error = rng_state.sub(rng_state_copy).max()
+ print(' max error in rng state (should be zero) on global rank {}: {}'.
+ format(torch.distributed.get_rank(), error))
+ assert error == 0
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+def test_cuda_rng_tracker(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing cuda rng tracker with size {} ...'.
+ format(model_parallel_size))
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ seed_1 = 1234
+ seed_2 = 4321
+ size = [12, 21]
+ tensor = torch.cuda.FloatTensor(size)
+
+ # Set to seed_1 and generate two tensors.
+ torch.cuda.manual_seed(seed_1)
+ torch.randn(size, out=tensor)
+ target_11 = tensor.clone()
+ torch.randn(size, out=tensor)
+ target_12 = tensor.clone()
+
+ # Set to seed_2 and generate two tensors.
+ torch.cuda.manual_seed(seed_2)
+ torch.randn(size, out=tensor)
+ target_21 = tensor.clone()
+ torch.randn(size, out=tensor)
+ target_22 = tensor.clone()
+
+ # Now if we interleave seed_1 and seed_2,
+ # we should still get the same tensors
+ torch.cuda.manual_seed(seed_1)
+ mpu.get_cuda_rng_tracker().add('test', seed_2)
+
+ torch.randn(size, out=tensor)
+ result_11 = tensor.clone()
+
+ with mpu.get_cuda_rng_tracker().fork('test'):
+ torch.randn(size, out=tensor)
+ result_21 = tensor.clone()
+
+ torch.randn(size, out=tensor)
+ result_12 = tensor.clone()
+
+ with mpu.get_cuda_rng_tracker().fork('test'):
+ torch.randn(size, out=tensor)
+ result_22 = tensor.clone()
+
+ diff = result_11.sub(result_21).abs().max()
+ diff = min(diff, result_12.sub(result_22).abs().max())
+ print(' max diff in generated tensors (should be non-zero) on '
+ 'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
+ assert diff > 1.0e-6
+ error = max(result_11.sub(target_11).abs().max(),
+ result_12.sub(target_12).abs().max())
+ error = max(error, result_21.sub(target_21).abs().max())
+ error = max(error, result_22.sub(target_22).abs().max())
+ print(' max error in generated tensors (should be zero) on '
+ 'global rank {}: {}'.format(torch.distributed.get_rank(), error))
+ assert error < 1.0e-6
+
+ # Reset the tracker
+ mpu.get_cuda_rng_tracker().reset()
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+def test_model_parallel_cuda_manual_seed(model_parallel_size):
+
+ if torch.distributed.get_rank() == 0:
+ print('> testing model parallel cuda manual seed with size {} ...'.
+ format(model_parallel_size))
+
+ mpu.initialize_model_parallel(model_parallel_size)
+ model_parallel_size = mpu.get_model_parallel_world_size()
+
+ mpu.model_parallel_cuda_manual_seed(12345)
+ assert torch.cuda.initial_seed() == 12345
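+ # The model-parallel rng stream is seeded with an offset (2718 + rank, as
+ # asserted below), so each partition draws an independent random stream.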
+ with mpu.get_cuda_rng_tracker().fork():
+ assert torch.cuda.initial_seed() == (12345 + 2718 +
+ mpu.get_model_parallel_rank())
+
+ # Reset the tracker
+ mpu.get_cuda_rng_tracker().reset()
+
+ # Reset groups
+ mpu.destroy_model_parallel()
+
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print('>> passed the test :-)')
+
+
+if __name__ == '__main__':
+
+ initialize_distributed()
+ world_size = torch.distributed.get_world_size()
+
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test set rng state')
+ test_set_cuda_rng_state(model_parallel_size)
+ model_parallel_size *= 2
+
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test cuda rng tracker')
+ test_cuda_rng_tracker(model_parallel_size)
+ model_parallel_size *= 2
+
+ model_parallel_size = 1
+ while model_parallel_size <= world_size:
+ print_separator('test model parallel cuda manual seed')
+ test_model_parallel_cuda_manual_seed(model_parallel_size)
+ model_parallel_size *= 2
diff --git a/megatron_lm/megatron/mpu/utils.py b/megatron_lm/megatron/mpu/utils.py
new file mode 100644
index 0000000..56ed1c7
--- /dev/null
+++ b/megatron_lm/megatron/mpu/utils.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+def ensure_divisibility(numerator, denominator):
+ """Ensure that numerator is divisible by the denominator."""
+ assert numerator % denominator == 0, '{} is not divisible by {}'.format(
+ numerator, denominator)
+
+
+def divide(numerator, denominator):
+ """Ensure that numerator is divisible by the denominator and return
+ the division value."""
+ ensure_divisibility(numerator, denominator)
+ return numerator // denominator
+
+
+def split_tensor_along_last_dim(tensor, num_partitions,
+ contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+
+class VocabUtility:
+ """Split the vocabulary into `world_size` chunks amd return the
+ first and last index of the vocabulary belonging to the `rank`
+ partition: Note that indecies in [fist, last)"""
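+ # e.g. with global_vocab_size=50304 and world_size=8, rank 2 owns ids [12576, 18864).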
+
+ @staticmethod
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
+ rank, world_size):
+ index_f = rank * per_partition_vocab_size
+ index_l = index_f + per_partition_vocab_size
+ return index_f, index_l
+
+ @staticmethod
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(
+ per_partition_vocab_size, rank, world_size)
diff --git a/megatron_lm/megatron/package_info.py b/megatron_lm/megatron/package_info.py
new file mode 100644
index 0000000..bd5decd
--- /dev/null
+++ b/megatron_lm/megatron/package_info.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MAJOR = 1
+MINOR = 1.5
+
+# Use the following formatting: (major, minor)
+VERSION = (MAJOR, MINOR)
+
+__version__ = '.'.join(map(str, VERSION))
+__package_name__ = 'megatron-lm'
+__contact_names__ = 'NVIDIA INC'
+__url__ = 'https://github.com/NVIDIA/Megatron-LM'
+__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
+__description__ = 'Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.'
+__license__ = 'See https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE'
+__keywords__ = 'deep learning, Megatron, gpu, NLP, nvidia, pytorch, torch, language'
+
diff --git a/megatron_lm/megatron/text_generation_utils.py b/megatron_lm/megatron/text_generation_utils.py
new file mode 100644
index 0000000..6721f46
--- /dev/null
+++ b/megatron_lm/megatron/text_generation_utils.py
@@ -0,0 +1,397 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for generating text."""
+
+import copy
+import json
+import os
+import time
+
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron.utils import get_ltor_masks_and_position_ids
+
+
+def get_batch(context_tokens):
+ """Generate batch from context tokens."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ # Move to GPU.
+ tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
+ # Get the attention mask and position ids.
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+ tokens,
+ tokenizer.eod,
+ args.reset_position_ids,
+ args.reset_attention_mask,
+ args.eod_mask_loss)
+
+ return tokens, attention_mask, position_ids
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+ """ This function has been mostly taken from huggingface conversational
+ ai code at
+ https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+ conversational-ai-with-transfer-learning-2d818ac26313 """
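+ # e.g. with logits [2.0, 1.0, 0.5] and top_k=2, the smallest entry is set
+ # to -inf so only the two most likely tokens can be sampled.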
+
+ if top_k > 0:
+ # Remove all tokens with a probability less than the
+ # last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p > 0.0:
+ # Sort logits so cumulative probabilities can be computed per row.
+ sorted_logits, sorted_indices = torch.sort(
+ logits, descending=True, dim=-1)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
+ dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold
+ sorted_indices_to_remove = cumulative_probs > top_p
+ # Shift the indices to the right to keep also the first token
+ # above the threshold
+ sorted_indices_to_remove[..., 1:] \
+ = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+ for i in range(sorted_indices.size(0)):
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+ logits[i][indices_to_remove] = filter_value
+
+ return logits
+
+
+def generate_samples_input_from_file(model):
+
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ # Read the sample file and open the output file.
+ assert args.sample_input_file is not None, \
+ 'sample input file is not provided.'
+ input_size = torch.cuda.LongTensor([0])
+ if mpu.get_model_parallel_rank() == 0:
+ with open(args.sample_input_file, "r") as fin:
+ inputs = [json.loads(line) for line in fin]
+ input_size[0] = len(inputs)
+ if args.sample_output_file is None:
+ sample_output_file = args.sample_input_file + ".out"
+ print('could not find `sample-output-file`, setting '
+ 'it to {}'.format(sample_output_file))
+ else:
+ sample_output_file = args.sample_output_file
+ output_list = []
+ torch.distributed.broadcast(
+ input_size, 0,
+ group=mpu.get_model_parallel_group()
+ )
+ input_size = input_size[0].item()
+
+ model.eval()
+
+ with torch.no_grad():
+ for i in range(input_size):
+
+ torch.distributed.barrier(group=mpu.get_model_parallel_group())
+
+ if mpu.get_model_parallel_rank() == 0:
+ raw_text = inputs[i][args.sample_context_field]
+ context_tokens = tokenizer.tokenize(raw_text)
+ context_length = len(context_tokens)
+ if context_length >= (args.seq_length // 2):
+ print("\nContext length", context_length,
+ "\nPlease give smaller context (half of the "
+ "sequence length)!", flush=True)
+ continue
+ else:
+ context_tokens = tokenizer.tokenize("EMPTY TEXT")
+ context_length = len(context_tokens)
+
+ token_stream = get_token_stream(model, [context_tokens])
+ for _, decode_tokens in enumerate(token_stream):
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+
+ if mpu.get_model_parallel_rank() == 0:
+ os.system('clear')
+ print("\nContext:", raw_text, flush=True)
+ trim_decode_tokens = tokenizer.detokenize(
+ decode_tokens)[len(raw_text):]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+ output_list.append({**inputs[i], args.sample_generated_field: trim_decode_tokens})
+
+ torch.distributed.barrier(group=mpu.get_model_parallel_group())
+ if mpu.get_model_parallel_rank() == 0:
+ with open(sample_output_file, 'w') as fout:
+ for row in output_list:
+ fout.write(json.dumps(row, ensure_ascii=False))
+ fout.write("\n")
+
+
+def generate_samples_interactive(model, print_frequency=24):
+
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_count = 0
+ model.eval()
+ with torch.no_grad():
+ while True:
+ torch.distributed.barrier(group=mpu.get_model_parallel_group())
+ terminate_runs = 0
+
+ if mpu.get_model_parallel_rank() == 0:
+ os.system('clear')
+ raw_text = input("\nContext prompt (stop to exit) >>> ")
+ while not raw_text:
+ print('Prompt should not be empty!')
+ raw_text = input("\nContext prompt (stop to exit) >>> ")
+
+ if "stop" in raw_text:
+ terminate_runs = 1
+ else:
+ context_tokens = tokenizer.tokenize(raw_text)
+ context_length = len(context_tokens)
+
+ if context_length >= (args.seq_length // 2):
+ print("\nContext length", context_length,
+ "\nPlease give smaller context (half of the "
+ "sequence length)!", flush=True)
+ continue
+ else:
+ context_tokens = tokenizer.tokenize("EMPTY TEXT")
+ context_length = len(context_tokens)
+
+ terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
+ torch.distributed.broadcast(terminate_runs_tensor,
+ mpu.get_model_parallel_src_rank(),
+ group=mpu.get_model_parallel_group())
+ terminate_runs = terminate_runs_tensor[0].item()
+
+ if terminate_runs == 1:
+ return
+
+ token_stream = get_token_stream(model, [context_tokens])
+ for counter, decode_tokens in enumerate(token_stream):
+ decode_tokens, _ = decode_tokens
+ decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+
+ if mpu.get_model_parallel_rank() == 0 and \
+ counter % print_frequency == 0:
+ os.system('clear')
+ print("\nContext:", raw_text, flush=True)
+ trim_decode_tokens = tokenizer.detokenize(
+ decode_tokens)[len(raw_text):]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+ if mpu.get_model_parallel_rank() == 0:
+ os.system('clear')
+ print("\nContext:", raw_text, flush=True)
+ trim_decode_tokens = tokenizer.detokenize(
+ decode_tokens)[len(raw_text):]
+ print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+ raw_text = None
+ torch.distributed.barrier(group=mpu.get_model_parallel_group())
+ context_count += 1
+
+ if mpu.get_model_parallel_rank() == 0:
+ input("\nPress any key to continue >>>")
+
+
+def generate_samples_unconditional(model):
+
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ num_samples = args.num_samples
+ context_tokens = [[tokenizer.BOS]
+ for _ in range(args.batch_size)]
+ ctr = 0
+ while True:
+ start_time = time.time()
+ for token_stream in get_token_stream(model,
+ copy.deepcopy(context_tokens)):
+ pass
+ if ctr % args.log_interval == 0:
+ print('Avg s/batch:',
+ (time.time() - start_time) / min(args.log_interval, ctr + 1))
+ start_time = time.time()
+ length = len(token_stream)
+ token_batch = token_stream[0].cpu().numpy().tolist()
+ length_batch = token_stream[1].cpu().numpy().tolist()
+ for tokens, length in zip(token_batch, length_batch):
+ tokens = tokens[1:length - 1]
+ text = tokenizer.detokenize(tokens)
+ is_finished = length < args.seq_length - 1
+ datum = {'text': text, 'length': length - 1, 'finished': is_finished}
+ yield datum
+ ctr += 1
+ if ctr >= num_samples:
+ break
+ if ctr >= num_samples:
+ break
+
+
+def generate_and_write_samples_unconditional(model):
+
+ args = get_args()
+ assert args.genfile is not None
+ with open(args.genfile, 'w') as f:
+ for datum in generate_samples_unconditional(model):
+ f.write(json.dumps(datum, ensure_ascii=False) + '\n')
+
+
+def pad_batch(batch, pad_id, args):
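+ # Right-pad every context with pad_id up to args.seq_length and record the original lengths.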
+
+ context_lengths = []
+ for tokens in batch:
+ context_length = len(tokens)
+ if context_length < args.seq_length:
+ tokens.extend([pad_id] * (args.seq_length - context_length))
+ context_lengths.append(context_length)
+ return batch, context_lengths
+
+
+def get_token_stream(model, context_tokens):
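+ # Pad the contexts, then broadcast them from the model-parallel source rank
+ # so every partition decodes the same batch.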
+
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ context_tokens, context_lengths = pad_batch(context_tokens,
+ tokenizer.eod, args)
+
+ context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+ context_length_tensor = torch.cuda.LongTensor(context_lengths)
+
+ torch.distributed.broadcast(context_length_tensor,
+ mpu.get_model_parallel_src_rank(),
+ group=mpu.get_model_parallel_group())
+ torch.distributed.broadcast(context_tokens_tensor,
+ mpu.get_model_parallel_src_rank(),
+ group=mpu.get_model_parallel_group())
+
+ context_length = context_length_tensor.min().item()
+ tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+
+ batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+ context_length_tensor,
+ attention_mask, position_ids)
+ for tokens, lengths in batch_token_iterator:
+ context_length += 1
+ yield tokens[:, :context_length], lengths
+
+
+def switch(val1, val2, boolean):
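+ # Element-wise select: positions where boolean is 1 take val2, the rest keep val1.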
+
+ boolean = boolean.type_as(val1)
+ return (1 - boolean) * val1 + boolean * val2
+
+
+def sample_sequence_batch(model, context_tokens, context_lengths,
+ attention_mask, position_ids,
+ maxlen=None, type_ids=None):
+
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ model.eval()
+ with torch.no_grad():
+ context_length = context_lengths.min().item()
+ eos_id = tokenizer.eod
+
+ counter = 0
+ org_context_length = context_length
+
+ layer_past = None
+ batch_size = context_tokens.size(0)
+ is_done = torch.zeros([batch_size]).byte().cuda()
+ tokens = context_tokens
+ if maxlen is None:
+ maxlen = args.seq_length - 1
+ if maxlen > (org_context_length + args.out_seq_length):
+ maxlen = org_context_length + args.out_seq_length - 1
+
+ lengths = torch.ones([batch_size]).long().cuda() * maxlen
+
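+ # Decode one token per iteration: with args.recompute the full prefix is
+ # re-run every step, otherwise cached key/values (layer_past) are reused
+ # for incremental decoding.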
+ while context_length <= (maxlen):
+
+ if args.recompute:
+ logits = model(tokens,
+ position_ids,
+ attention_mask,
+ tokentype_ids=type_ids,
+ forward_method_parallel_output=False)
+ logits = logits[:, context_length - 1, :]
+ else:
+ types2use = None
+ if counter == 0:
+ tokens2use = tokens[:, :context_length]
+ positions2use = position_ids[:, :context_length]
+ if type_ids is not None:
+ types2use = type_ids[:, :context_length]
+ else:
+ tokens2use = tokens[:, context_length - 1].view(
+ batch_size, -1)
+ positions2use = position_ids[:, context_length - 1].view(
+ batch_size, -1)
+ if type_ids is not None:
+ types2use = type_ids[:, context_length - 1].view(
+ batch_size, -1)
+ logits, layer_past = model(tokens2use,
+ positions2use,
+ attention_mask,
+ layer_past=layer_past,
+ get_key_value=True,
+ tokentype_ids=types2use,
+ forward_method_parallel_output=False)
+ logits = logits[:, -1].view(batch_size, -1).contiguous()
+
+ if args.greedy:
+ prev = torch.argmax(logits, dim=-1).view(-1)
+ else:
+ logits = logits.float()
+ logits /= args.temperature
+ logits = top_k_logits(logits, top_k=args.top_k,
+ top_p=args.top_p)
+ log_probs = F.softmax(logits, dim=-1)
+ prev = torch.multinomial(log_probs, num_samples=1).view(-1)
+
+ print_logits = []
+ for p in prev:
+ print_logits.append([logits[i, p].item()
+ for i in range(batch_size)])
+ started = context_lengths <= context_length
+ tokens[:, context_length] = switch(
+ tokens[:, context_length].view(-1), prev, started)
+ context_length += 1
+ counter += 1
+
+ done_token = (prev == eos_id).byte() & started.byte()
+ just_finished = (done_token & ~is_done).bool()
+ lengths[just_finished.view(-1)] = context_length
+ is_done = is_done | done_token
+ done = torch.all(is_done)
+
+ yield tokens, lengths
+ if done:
+ break
diff --git a/megatron_lm/megatron/tokenizer/__init__.py b/megatron_lm/megatron/tokenizer/__init__.py
new file mode 100644
index 0000000..311f2fd
--- /dev/null
+++ b/megatron_lm/megatron/tokenizer/__init__.py
@@ -0,0 +1,17 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .tokenizer import build_tokenizer
diff --git a/megatron_lm/megatron/tokenizer/bert_tokenization.py b/megatron_lm/megatron/tokenizer/bert_tokenization.py
new file mode 100644
index 0000000..a3aa6d9
--- /dev/null
+++ b/megatron_lm/megatron/tokenizer/bert_tokenization.py
@@ -0,0 +1,402 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import unicodedata
+import six
+
+
+def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
+ """Checks whether the casing config is consistent with the checkpoint name."""
+
+ # The casing has to be passed in by the user and there is no explicit check
+ # as to whether it matches the checkpoint. The casing information probably
+ # should have been stored in the bert_config.json file, but it's not, so
+ # we have to heuristically detect it to validate.
+
+ if not init_checkpoint:
+ return
+
+ m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
+ if m is None:
+ return
+
+ model_name = m.group(1)
+
+ lower_models = [
+ "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
+ "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
+ ]
+
+ cased_models = [
+ "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
+ "multi_cased_L-12_H-768_A-12"
+ ]
+
+ is_bad_config = False
+ if model_name in lower_models and not do_lower_case:
+ is_bad_config = True
+ actual_flag = "False"
+ case_name = "lowercased"
+ opposite_flag = "True"
+
+ if model_name in cased_models and do_lower_case:
+ is_bad_config = True
+ actual_flag = "True"
+ case_name = "cased"
+ opposite_flag = "False"
+
+ if is_bad_config:
+ raise ValueError(
+ "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
+ "However, `%s` seems to be a %s model, so you "
+ "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
+ "how the model was pre-training. If this error is wrong, please "
+ "just comment out this check." % (actual_flag, init_checkpoint,
+ model_name, case_name, opposite_flag))
+
+
+def convert_to_unicode(text):
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+ if six.PY3:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode("utf-8", "ignore")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ elif six.PY2:
+ if isinstance(text, str):
+ return text.decode("utf-8", "ignore")
+ elif isinstance(text, unicode):
+ return text
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ else:
+ raise ValueError("Not running on Python2 or Python 3?")
+
+
+def printable_text(text):
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
+
+ # These functions want `str` for both Python2 and Python3, but in one case
+ # it's a Unicode string and in the other it's a byte string.
+ if six.PY3:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode("utf-8", "ignore")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ elif six.PY2:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, unicode):
+ return text.encode("utf-8")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ else:
+ raise ValueError("Not running on Python2 or Python 3?")
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ index = 0
+ with open(vocab_file, "r") as reader:
+ while True:
+ token = convert_to_unicode(reader.readline())
+ if not token:
+ break
+ token = token.strip()
+ vocab[token] = index
+ index += 1
+ return vocab
+
+
+def convert_by_vocab(vocab, items):
+ """Converts a sequence of [tokens|ids] using the vocab."""
+ output = []
+ for item in items:
+ output.append(vocab[item])
+ return output
+
+
+def convert_tokens_to_ids(vocab, tokens):
+ return convert_by_vocab(vocab, tokens)
+
+
+def convert_ids_to_tokens(inv_vocab, ids):
+ return convert_by_vocab(inv_vocab, ids)
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class FullTokenizer(object):
+ """Runs end-to-end tokenziation."""
+
+ def __init__(self, vocab_file, do_lower_case=True):
+ self.vocab = load_vocab(vocab_file)
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+
+ def tokenize(self, text):
+ split_tokens = []
+ for token in self.basic_tokenizer.tokenize(text):
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
+ split_tokens.append(sub_token)
+
+ return split_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ return convert_by_vocab(self.vocab, tokens)
+
+ def convert_ids_to_tokens(self, ids):
+ return convert_by_vocab(self.inv_vocab, ids)
+
+ def vocab_size(self):
+ return len(self.vocab)
+
+
+class BasicTokenizer(object):
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+ def __init__(self, do_lower_case=True):
+ """Constructs a BasicTokenizer.
+
+ Args:
+ do_lower_case: Whether to lower case the input.
+ """
+ self.do_lower_case = do_lower_case
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text."""
+ text = convert_to_unicode(text)
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ text = self._tokenize_chinese_chars(text)
+
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if self.do_lower_case:
+ token = token.lower()
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text):
+ """Splits punctuation on a piece of text."""
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xfffd or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+class WordpieceTokenizer(object):
+ """Runs WordPiece tokenziation."""
+
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text into its word pieces.
+
+ This uses a greedy longest-match-first algorithm to perform tokenization
+ using the given vocabulary.
+
+ For example:
+ input = "unaffable"
+ output = ["un", "##aff", "##able"]
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through `BasicTokenizer`.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ text = convert_to_unicode(text)
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
+def _is_whitespace(char):
+ """Checks whether `chars` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def _is_control(char):
+ """Checks whether `chars` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat in ("Cc", "Cf"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+ """Checks whether `chars` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
diff --git a/megatron_lm/megatron/tokenizer/gpt2_tokenization.py b/megatron_lm/megatron/tokenizer/gpt2_tokenization.py
new file mode 100644
index 0000000..3f37e44
--- /dev/null
+++ b/megatron_lm/megatron/tokenizer/gpt2_tokenization.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for OpenAI GPT."""
+
+from __future__ import (absolute_import, division, print_function,
+ unicode_literals)
+
+import sys
+import json
+import logging
+import os
+import regex as re
+from io import open
+
+try:
+ from functools import lru_cache
+except ImportError:
+ # Just a dummy decorator to get the checks to run on python2
+ # because honestly I don't want to support a byte-level unicode BPE
+ # tokenizer on python 2 right now.
+ def lru_cache():
+ return lambda func: func
+
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+ 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+}
+PRETRAINED_MERGES_ARCHIVE_MAP = {
+ 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+ 'gpt2': 1024,
+}
+VOCAB_NAME = 'vocab.json'
+MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ _chr = unichr if sys.version_info[0] == 2 else chr
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+ list(range(ord("®"), ord("ÿ") + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [_chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class GPT2Tokenizer(object):
+ """
+ GPT-2 BPE tokenizer. Peculiarities:
+ - Byte-level BPE
+ """
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+ """
+ Instantiate a PreTrainedBertModel from a pre-trained model file.
+ Download and cache the pre-trained model file if needed.
+ """
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+ vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+ merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+ special_tokens_file = None
+ else:
+ vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+ merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+ special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+ if not os.path.exists(special_tokens_file):
+ special_tokens_file = None
+ else:
+ logger.info("loading special tokens file {}".format(special_tokens_file))
+ # redirect to the cache, if necessary
+ try:
+ from .file_utils import cached_path
+ resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+ resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
+ except EnvironmentError:
+ logger.error(
+ "Model name '{}' was not found in model name list ({}). "
+ "We assumed '{}' was a path or url but couldn't find files {} and {} "
+ "at this path or url.".format(
+ pretrained_model_name_or_path,
+ ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+ pretrained_model_name_or_path,
+ vocab_file, merges_file))
+ return None
+ if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+ logger.info("loading vocabulary file {}".format(vocab_file))
+ logger.info("loading merges file {}".format(merges_file))
+ else:
+ logger.info("loading vocabulary file {} from cache at {}".format(
+ vocab_file, resolved_vocab_file))
+ logger.info("loading merges file {} from cache at {}".format(
+ merges_file, resolved_merges_file))
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+ # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
+ # than the number of positional embeddings
+ max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+ kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+ # Instantiate tokenizer.
+ if special_tokens_file and 'special_tokens' not in kwargs:
+ special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+ else:
+ special_tokens = kwargs.pop('special_tokens', [])
+ tokenizer = cls(
+ resolved_vocab_file,
+ resolved_merges_file,
+ special_tokens=special_tokens,
+ *inputs,
+ **kwargs)
+ return tokenizer
+
+ def __init__(self, vocab_file, merges_file, errors='replace',
+ special_tokens=None, max_len=None):
+ self.max_len = max_len if max_len is not None else int(1e12)
+ self.encoder = json.load(open(vocab_file))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+
+ # Should have added re.IGNORECASE so BPE merges can happen for
+ # capitalized versions of contractions
+ self.pat = re.compile(
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ self.special_tokens = {}
+ self.special_tokens_decoder = {}
+ self.set_special_tokens(special_tokens)
+
+ def __len__(self):
+ return len(self.encoder) + len(self.special_tokens)
+
+ def set_special_tokens(self, special_tokens):
+ """ Add a list of additional tokens to the encoder.
+ The additional tokens are indexed starting from the last index of the
+ current vocabulary in the order of the `special_tokens` list.
+ """
+ if not special_tokens:
+ self.special_tokens = {}
+ self.special_tokens_decoder = {}
+ return
+ self.special_tokens = dict((tok, len(self.encoder) + i)
+ for i, tok in enumerate(special_tokens))
+ self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
+ logger.info("Special tokens {}".format(self.special_tokens))
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except BaseException:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def tokenize(self, text):
+ """ Tokenize a string. """
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ if sys.version_info[0] == 2:
+ token = ''.join(self.byte_encoder[ord(b)] for b in token)
+ else:
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ """ Converts a sequence of tokens into ids using the vocab. """
+ ids = []
+ if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
+ if tokens in self.special_tokens:
+ return self.special_tokens[tokens]
+ else:
+ return self.encoder.get(tokens, 0)
+ for token in tokens:
+ if token in self.special_tokens:
+ ids.append(self.special_tokens[token])
+ else:
+ ids.append(self.encoder.get(token, 0))
+ if len(ids) > self.max_len:
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum "
+ " sequence length for this OpenAI GPT model ({} > {}). Running this"
+ " sequence through the model will result in indexing errors".format(
+ len(ids), self.max_len)
+ )
+ return ids
+
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+ """Converts a sequence of ids in BPE tokens using the vocab."""
+ tokens = []
+ for i in ids:
+ if i in self.special_tokens_decoder:
+ if not skip_special_tokens:
+ tokens.append(self.special_tokens_decoder[i])
+ else:
+ tokens.append(self.decoder[i])
+ return tokens
+
+ def encode(self, text):
+ return self.convert_tokens_to_ids(self.tokenize(text))
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+ return text
+
+ def save_vocabulary(self, vocab_path):
+ """Save the tokenizer vocabulary and merge files to a directory."""
+ if not os.path.isdir(vocab_path):
+ logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+ return
+ vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+ merge_file = os.path.join(vocab_path, MERGES_NAME)
+ special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+ with open(vocab_file, 'w', encoding='utf-8') as f:
+ f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write(u'#version: 0.2\n')
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(merge_file))
+ index = token_index
+ writer.write(' '.join(bpe_tokens) + u'\n')
+ index += 1
+
+ index = len(self.encoder)
+ with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+ for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
+ index = token_index
+ writer.write(token + u'\n')
+ index += 1
+
+ return vocab_file, merge_file, special_tokens_file
diff --git a/megatron_lm/megatron/tokenizer/sp_tokenization.py b/megatron_lm/megatron/tokenizer/sp_tokenization.py
new file mode 100644
index 0000000..0df6746
--- /dev/null
+++ b/megatron_lm/megatron/tokenizer/sp_tokenization.py
@@ -0,0 +1,78 @@
+import six
+import sentencepiece as spm
+
+
+def convert_to_unicode(text):
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+ return six.ensure_text(text, errors="ignore")
+
+
+class SentencePieceTokenizer:
+ NEW_LINE = "[NL]"
+ UNK = 0
+ BOS = 1
+ EOS = 2
+ BOS_TOKEN = "<s>"
+ EOS_TOKEN = "</s>"
+ MASK_TOKEN = "[MASK]"
+
+ def __init__(self, vocab_file):
+ self.name = "sp"
+ self._tokenizer = spm.SentencePieceProcessor(model_file=vocab_file)
+ self._vocab_words = self._get_vocab_words()
+ self.encoder = {token: idx for idx, token in enumerate(self._vocab_words)}
+ self.decoder = {idx: token for idx, token in enumerate(self._vocab_words)}
+
+ mask_tokens = self.convert_tokens_to_ids([self.MASK_TOKEN])
+ assert len(mask_tokens) == 1
+ self.MASK = mask_tokens[0]
+
+ def _encode(self, line, out_type=str):
+ return self._tokenizer.encode(line, out_type=out_type)
+
+ def tokenize(self, line, out_type=int):
+ line = convert_to_unicode(line)
+ line = line.replace("\n", SentencePieceTokenizer.NEW_LINE)
+ return self._encode(line, out_type=out_type) # BOS will be added in another wrapper
+
+ def convert_tokens_to_ids(self, tokens):
+ return self._tokenizer.piece_to_id(tokens)
+
+ def convert_ids_to_tokens(self, ids):
+ return [self.decoder[idx] for idx in ids]
+
+ def get_tokens(self):
+ return self._vocab_words
+
+ def _get_vocab_words(self):
+ indices = list(range(self._tokenizer.GetPieceSize()))
+ return self._tokenizer.id_to_piece(indices)
+
+ @property
+ def vocab(self):
+ return self.encoder
+
+ @property
+ def inv_vocab(self):
+ return self.decoder
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def detokenize(self, token_ids):
+ tokens = [self.decoder[idx] for idx in token_ids]
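+ # "\u2581" is the SentencePiece meta symbol (▁) marking word boundaries, so it becomes a space.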
+ text = "".join(tokens).replace("\u2581", " ").replace(self.EOS_TOKEN, "").lstrip()
+ return text
+
+ @property
+ def cls(self):
+ return self.BOS
+
+ @property
+ def eod(self):
+ return self.EOS
+
+ @property
+ def mask(self):
+ return self.MASK
diff --git a/megatron_lm/megatron/tokenizer/tokenizer.py b/megatron_lm/megatron/tokenizer/tokenizer.py
new file mode 100644
index 0000000..9637723
--- /dev/null
+++ b/megatron_lm/megatron/tokenizer/tokenizer.py
@@ -0,0 +1,280 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron tokenizers."""
+
+from abc import ABC
+from abc import abstractmethod
+
+from .bert_tokenization import FullTokenizer as FullBertTokenizer
+from .gpt2_tokenization import GPT2Tokenizer
+from .sp_tokenization import SentencePieceTokenizer
+
+
+def build_tokenizer(args):
+ """Initialize tokenizer."""
+ if args.rank == 0:
+ print('> building {} tokenizer ...'.format(args.tokenizer_type),
+ flush=True)
+
+ # Select and instantiate the tokenizer.
+ assert args.vocab_file is not None
+ if args.tokenizer_type == 'SentencePiece':
+ tokenizer = _SpTokenizer(vocab_file=args.vocab_file)
+ elif args.tokenizer_type == 'BertWordPieceLowerCase':
+ tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+ lower_case=True)
+ elif args.tokenizer_type == 'BertWordPieceCase':
+ tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+ lower_case=False)
+ elif args.tokenizer_type == 'GPT2BPETokenizer':
+ assert args.merge_file is not None
+ tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
+ else:
+ raise NotImplementedError('{} tokenizer is not '
+ 'implemented.'.format(args.tokenizer_type))
+
+ # Add vocab size.
+ args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
+ args)
+
+ return tokenizer
+
+
+def _vocab_size_with_padding(orig_vocab_size, args):
+ """Pad vocab size so it is divisible by model parallel size and
+ still having GPU friendly size."""
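+ # e.g. an original vocab of 50257 with make_vocab_size_divisible_by=128 and
+ # model_parallel_size=2 (multiple=256) is padded to 50432.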
+
+ after = orig_vocab_size
+ multiple = args.make_vocab_size_divisible_by * \
+ args.model_parallel_size
+ while (after % multiple) != 0:
+ after += 1
+ if args.rank == 0:
+ print(' > padded vocab (size: {}) with {} dummy tokens '
+ '(new size: {})'.format(
+ orig_vocab_size, after - orig_vocab_size, after), flush=True)
+ return after
+
+
+class AbstractTokenizer(ABC):
+ """Abstract class for tokenizer."""
+
+ def __init__(self, name):
+ self.name = name
+ super().__init__()
+
+ @property
+ @abstractmethod
+ def vocab_size(self):
+ pass
+
+ @property
+ @abstractmethod
+ def vocab(self):
+ """Dictionary from vocab text token to id token."""
+ pass
+
+ @property
+ @abstractmethod
+ def inv_vocab(self):
+ """Dictionary from vocab id token to text token."""
+ pass
+
+ @abstractmethod
+ def tokenize(self, text):
+ pass
+
+ def detokenize(self, token_ids):
+ raise NotImplementedError('detokenizer is not implemented for {} '
+ 'tokenizer'.format(self.name))
+
+ @property
+ def cls(self):
+ raise NotImplementedError('CLS is not provided for {} '
+ 'tokenizer'.format(self.name))
+
+ @property
+ def sep(self):
+ raise NotImplementedError('SEP is not provided for {} '
+ 'tokenizer'.format(self.name))
+
+ @property
+ def pad(self):
+ raise NotImplementedError('PAD is not provided for {} '
+ 'tokenizer'.format(self.name))
+
+ @property
+ def eod(self):
+ raise NotImplementedError('EOD is not provided for {} '
+ 'tokenizer'.format(self.name))
+
+ @property
+ def mask(self):
+ raise NotImplementedError('MASK is not provided for {} '
+ 'tokenizer'.format(self.name))
+
+
+class _BertWordPieceTokenizer(AbstractTokenizer):
+ """Original BERT wordpiece tokenizer."""
+
+ def __init__(self, vocab_file, lower_case=True):
+ if lower_case:
+ name = 'BERT Lower Case'
+ else:
+ name = 'BERT Upper Case'
+ super().__init__(name)
+ self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
+ self.cls_id = self.tokenizer.vocab['[CLS]']
+ self.sep_id = self.tokenizer.vocab['[SEP]']
+ self.pad_id = self.tokenizer.vocab['[PAD]']
+ self.mask_id = self.tokenizer.vocab['[MASK]']
+
+ @property
+ def vocab_size(self):
+ return self.tokenizer.vocab_size()
+
+ @property
+ def vocab(self):
+ return self.tokenizer.vocab
+
+ @property
+ def inv_vocab(self):
+ return self.tokenizer.inv_vocab
+
+ def tokenize(self, text):
+ text_tokens = self.tokenizer.tokenize(text)
+ return self.tokenizer.convert_tokens_to_ids(text_tokens)
+
+ def decode_token_ids(self, token_ids):
+ tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+ exclude_list = ['[PAD]', '[CLS]']
+ non_pads = [t for t in tokens if t not in exclude_list]
+
+ result = ""
+ for s in non_pads:
+ if s.startswith("##"):
+ result += s[2:]
+ else:
+ result += " " + s
+
+ return result
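+
+    # Example (illustrative): ids for ['[CLS]', 'hel', '##lo', '[PAD]'] decode to
+    # ' hello' (note the leading space): '[PAD]' and '[CLS]' are dropped, '##'
+    # pieces are glued to the previous piece, and each new word gets a space.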
+
+ @property
+ def cls(self):
+ return self.cls_id
+
+ @property
+ def sep(self):
+ return self.sep_id
+
+ @property
+ def pad(self):
+ return self.pad_id
+
+ @property
+ def mask(self):
+ return self.mask_id
+
+
+class _GPT2BPETokenizer(AbstractTokenizer):
+ """Original GPT2 BPE tokenizer."""
+
+ def __init__(self, vocab_file, merge_file):
+ name = 'GPT2 BPE'
+ super().__init__(name)
+
+ self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
+ special_tokens=[], max_len=None)
+ self.eod_id = self.tokenizer.encoder['<|endoftext|>']
+
+ @property
+ def vocab_size(self):
+ return len(self.tokenizer.encoder)
+
+ @property
+ def vocab(self):
+ return self.tokenizer.encoder
+
+ @property
+ def inv_vocab(self):
+ return self.tokenizer.decoder
+
+ def tokenize(self, text):
+ return self.tokenizer.encode(text)
+
+ def detokenize(self, token_ids):
+ return self.tokenizer.decode(token_ids)
+
+ @property
+ def eod(self):
+ return self.eod_id
+
+
+class _SpTokenizer(AbstractTokenizer):
+ NEW_LINE = "[NL]"
+ UNK = 0
+ BOS = 1
+ EOS = 2
+ BOS_TOKEN = "<s>"
+ EOS_TOKEN = "</s>"
+ MASK_TOKEN = "[MASK]"
+
+ def __init__(self, vocab_file, strip_eos=False):
+ name = "SP"
+ super().__init__(name)
+ self.tokenizer = SentencePieceTokenizer(vocab_file)
+ self.strip_eos = strip_eos
+
+ self.encoder = {}
+ self.decoder = {}
+ for id, token in enumerate(self.tokenizer.get_tokens()):
+ self.encoder[token] = id
+ self.decoder[id] = token
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ @property
+ def vocab(self):
+ return self.encoder
+
+ @property
+ def inv_vocab(self):
+ return self.decoder
+
+ def tokenize(self, text):
+ tokens = self.tokenizer.tokenize(text, out_type=str)
+ tokens = [self.BOS_TOKEN] + tokens
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+ return token_ids
+
+ def detokenize(self, token_ids):
+ return self.tokenizer._tokenizer.detokenize(token_ids)
+
+ @property
+ def cls(self):
+ return self.tokenizer.BOS
+
+ @property
+ def eod(self):
+ return self.tokenizer.EOS
+
+ @property
+ def mask(self):
+ tokens = self.tokenizer.convert_tokens_to_ids([self.MASK_TOKEN])
+ assert len(tokens) == 1
+ return tokens[0]
diff --git a/megatron_lm/megatron/training.py b/megatron_lm/megatron/training.py
new file mode 100644
index 0000000..54359fa
--- /dev/null
+++ b/megatron_lm/megatron/training.py
@@ -0,0 +1,685 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pretrain utilities."""
+
+from datetime import datetime
+import math
+import sys
+import torch
+import json
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from apex.optimizers import FusedAdam as Adam
+
+from megatron import get_args
+from megatron import get_timers
+from megatron import get_tensorboard_writer
+from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
+from megatron.fp16 import FP16_Module
+from megatron.fp16 import FP16_Optimizer
+from megatron.initialize import initialize_megatron
+from megatron.learning_rates import AnnealingLR
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import get_params_for_weight_decay_optimization
+from megatron.model.realm_model import ICTBertModel
+from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import make_data_loader
+from megatron.utils import report_memory, flops_calculator
+
+import deepspeed
+from deepspeed.runtime.utils import see_memory_usage
+
+
+def pretrain(train_valid_test_dataset_provider, model_provider,
+ forward_step_func, extra_args_provider=None, args_defaults={}):
+ """Main training program.
+
+    This function will run the following in the order provided:
+        1) initialize Megatron.
+        2) setup model, optimizer and lr schedule using the model_provider.
+        3) call train_val_test_data_provider to get train/val/test datasets.
+        4) train the model using the forward_step_func.
+
+    Arguments:
+        train_valid_test_dataset_provider: a function that takes the size of
+            train/valid/test dataset and returns `train, valid, test` datasets.
+        model_provider: a function that returns a vanilla version of the
+            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
+        forward_step_func: a function that takes a `data iterator` and `model`,
+            and returns a `loss` scalar and a dictionary whose key:value pairs
+            are the quantities we would like to monitor during training, for
+            example `lm-loss: value`. We also require that this function adds
+            a `batch generator` timer to the timers class.
+        extra_args_provider: a function that takes a parser and adds arguments
+            to it. It is used for programs to add their own arguments.
+        args_defaults: a dictionary from argument-name to argument-value, used
+            to set defaults for already-parsed arguments.
+ """
+
+    # Initialize and get arguments, timers, and Tensorboard writer.
+ initialize_megatron(extra_args_provider=extra_args_provider,
+ args_defaults=args_defaults)
+
+ args = get_args()
+ timers = get_timers()
+
+ args.curriculum_learning = False
+ if args.deepspeed:
+ args.deepspeed_configuration = json.load(
+ open(args.deepspeed_config, 'r', encoding='utf-8'))
+ if "curriculum_learning" in args.deepspeed_configuration:
+ if "enabled" in args.deepspeed_configuration["curriculum_learning"]:
+ args.curriculum_learning = args.deepspeed_configuration["curriculum_learning"]["enabled"]
+
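+    # For reference: only the "enabled" flag of an optional "curriculum_learning"
+    # block in the DeepSpeed config is read here; an illustrative snippet is
+    #     "curriculum_learning": { "enabled": true }
+    # and any further curriculum settings are left to DeepSpeed itself.
+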
+ # Model, optimizer, and learning rate.
+ timers('model and optimizer').start()
+ model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+ timers('model and optimizer').stop()
+
+ # Data stuff.
+ timers('train/valid/test data iterators').start()
+ train_data_iterator, valid_data_iterator, test_data_iterator \
+ = build_train_valid_test_data_iterators(
+ train_valid_test_dataset_provider)
+ timers('train/valid/test data iterators').stop()
+
+ # Print setup timing.
+ print_rank_0('done with setups ...')
+ timers.log(['model and optimizer', 'train/valid/test data iterators'])
+ print_rank_0('training ...')
+
+ iteration = 0
+
+ # save_checkpoint(iteration, model, optimizer, lr_scheduler) # force save
+
+ if args.do_train and args.train_iters > 0:
+ iteration = train(forward_step_func,
+ model, optimizer, lr_scheduler,
+ train_data_iterator, valid_data_iterator)
+
+ if args.do_valid:
+ prefix = 'the end of training for val data'
+ evaluate_and_print_results(prefix, forward_step_func,
+ valid_data_iterator, model,
+ iteration, False)
+
+ if args.save and iteration != 0:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+ if args.do_test:
+ # Run on test data.
+ prefix = 'the end of training for test data'
+ evaluate_and_print_results(prefix, forward_step_func,
+ test_data_iterator, model,
+ 0, True)
+
+
+def get_model(model_provider_func):
+ """Build the model."""
+ args = get_args()
+
+ # Build model on cpu.
+ model = model_provider_func()
+
+ if args.deepspeed:
+ # DeepSpeed handles CUDA, FP16, and DDP components.
+ return model
+
+ # GPU allocation.
+ model.cuda(torch.cuda.current_device())
+
+ # Fp16 conversion.
+ if args.fp16:
+ model = FP16_Module(model)
+
+    # Wrap model for distributed training.
+ if args.DDP_impl == 'torch':
+ i = torch.cuda.current_device()
+ model = torchDDP(model, device_ids=[i], output_device=i,
+ process_group=mpu.get_data_parallel_group())
+ return model
+ if args.DDP_impl == 'local':
+ model = LocalDDP(model)
+ return model
+
+ raise NotImplementedError('Unknown DDP implementation specified: {}. '
+ 'Exiting.'.format(args.DDP_impl))
+
+
+def get_optimizer(model):
+ """Set up the optimizer."""
+ args = get_args()
+
+ # Build parameter groups (weight decay and non-decay).
+ while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
+ model = model.module
+ param_groups = get_params_for_weight_decay_optimization(model)
+
+ # Add model parallel attribute if it is not set.
+ for param_group in param_groups:
+ for param in param_group['params']:
+ if not hasattr(param, 'model_parallel'):
+ param.model_parallel = False
+
+ if args.cpu_optimizer:
+ if args.cpu_torch_adam:
+ cpu_adam_optimizer = torch.optim.AdamW
+ else:
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
+ cpu_adam_optimizer = DeepSpeedCPUAdam
+ optimizer = cpu_adam_optimizer(param_groups,
+ lr=args.lr,
+ weight_decay=args.weight_decay)
+ else:
+        # Use torch AdamW instead of FusedAdam from NVIDIA, which seems to have some issues.
+ #optimizer = Adam(param_groups,
+ optimizer = torch.optim.AdamW(param_groups,
+ lr=args.lr,
+ weight_decay=args.weight_decay,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_eps)
+
+ if args.deepspeed:
+ # fp16 wrapper is not required for DeepSpeed.
+ return optimizer
+
+ # Wrap into fp16 optimizer.
+ if args.fp16:
+ optimizer = FP16_Optimizer(optimizer,
+ static_loss_scale=args.loss_scale,
+ dynamic_loss_scale=args.dynamic_loss_scale,
+ dynamic_loss_args={
+ 'scale_window': args.loss_scale_window,
+ 'min_scale': args.min_scale,
+ 'delayed_shift': args.hysteresis})
+
+ return optimizer
+
+
+def get_learning_rate_scheduler(optimizer):
+ """Build the learning rate scheduler."""
+ args = get_args()
+
+    # Build the annealing learning rate scheduler.
+ if args.lr_decay_iters is not None:
+ num_iters = args.lr_decay_iters
+ else:
+ num_iters = args.train_iters
+ num_iters = max(1, num_iters)
+ init_step = 0
+ if args.warmup_iters is not None:
+ warmup_iter = args.warmup_iters
+ else:
+ warmup_iter = args.warmup * num_iters
+ lr_scheduler = AnnealingLR(
+ optimizer,
+ start_lr=args.lr,
+ warmup_iter=warmup_iter,
+ total_iters=num_iters,
+ decay_style=args.lr_decay_style,
+ last_iter=init_step,
+ min_lr=args.min_lr,
+ use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
+ override_lr_scheduler=args.override_lr_scheduler)
+
+ return lr_scheduler
+
+
+def setup_model_and_optimizer(model_provider_func):
+ """Setup model and optimizer."""
+ args = get_args()
+
+ model = get_model(model_provider_func)
+ optimizer = get_optimizer(model)
+ lr_scheduler = get_learning_rate_scheduler(optimizer)
+
+ if args.deepspeed:
+ print_rank_0("DeepSpeed is enabled.")
+
+ model, optimizer, _, lr_scheduler = deepspeed.initialize(
+ model=model,
+ optimizer=optimizer,
+ args=args,
+ lr_scheduler=lr_scheduler,
+ mpu=mpu,
+ dist_init_required=False)
+ if args.load is not None:
+ args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+ else:
+ args.iteration = 0
+
+ # get model without FP16 and/or TorchDDP wrappers
+ unwrapped_model = model
+ while hasattr(unwrapped_model, 'module'):
+ unwrapped_model = unwrapped_model.module
+
+ if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
+ print("Initializing ICT from pretrained BERT model", flush=True)
+ unwrapped_model.init_state_dict_from_bert()
+
+ return model, optimizer, lr_scheduler
+
+
+def backward_step(optimizer, model, loss):
+ """Backward step."""
+ args = get_args()
+ timers = get_timers()
+
+ # Backward pass.
+ timers('backward-backward').start()
+ if args.deepspeed:
+ model.backward(loss)
+ else:
+ optimizer.zero_grad(set_grads_to_None=True)
+ if args.fp16:
+ optimizer.backward(loss, update_master_grads=False)
+ else:
+ loss.backward()
+ timers('backward-backward').stop()
+
+ if args.deepspeed:
+        # DeepSpeed's backward pass already handled the all-reduce communication.
+ # Reset the timer to avoid breaking timer logs below.
+ timers('backward-allreduce').reset()
+ else:
+ # All-reduce if needed.
+ if args.DDP_impl == 'local':
+ timers('backward-allreduce').start()
+ model.allreduce_params(reduce_after=False,
+ fp32_allreduce=args.fp32_allreduce)
+ timers('backward-allreduce').stop()
+
+ if not args.deepspeed:
+ # Update master gradients.
+ timers('backward-master-grad').start()
+ if args.fp16:
+ optimizer.update_master_grads()
+ timers('backward-master-grad').stop()
+
+    # Clipping gradients helps prevent exploding gradients.
+ timers('backward-clip-grad').start()
+ if args.clip_grad > 0:
+ if not args.fp16:
+ mpu.clip_grad_norm(model.parameters(), args.clip_grad)
+ else:
+ optimizer.clip_master_grads(args.clip_grad)
+ timers('backward-clip-grad').stop()
+
+
+def train_step(forward_step_func, data_iterator,
+ model, optimizer, lr_scheduler):
+ """Single training step."""
+ args = get_args()
+ timers = get_timers()
+
+ #see_memory_usage(f'before forward {model.global_steps}', force=True)
+ # Forward model for one step.
+ timers('forward').start()
+ loss, loss_reduced = forward_step_func(data_iterator, model)
+ timers('forward').stop()
+
+ #see_memory_usage(f'before backward {model.global_steps}', force=True)
+ # Calculate gradients, reduce across processes, and clip.
+ timers('backward').start()
+ backward_step(optimizer, model, loss)
+ timers('backward').stop()
+
+
+ #see_memory_usage(f'before optimizer {model.global_steps}', force=True)
+ # Update parameters.
+ skipped_iter = 0
+ timers('optimizer').start()
+ if args.deepspeed:
+ model.step()
+ else:
+ optimizer.step()
+ # Update learning rate.
+ if not (args.fp16 and optimizer.overflow):
+ lr_scheduler.step()
+ else:
+ skipped_iter = 1
+ timers('optimizer').stop()
+
+ return loss_reduced, skipped_iter
+
+
+def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
+ loss_scale, report_memory_flag, skipped_iter, model=None):
+ """Log training information such as losses, timing, ...."""
+ args = get_args()
+ timers = get_timers()
+ writer = get_tensorboard_writer()
+
+ # Update losses.
+ skipped_iters_key = 'skipped iterations'
+ total_loss_dict[skipped_iters_key] = total_loss_dict.get(
+ skipped_iters_key, 0) + skipped_iter
+ got_nan_key = 'got nan'
+
+ got_nan = False
+ for key in loss_dict:
+ if not skipped_iter:
+ total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
+ else:
+ value = loss_dict[key].float().sum().item()
+ is_nan = value == float('inf') or \
+ value == -float('inf') or \
+ value != value
+ got_nan = got_nan or is_nan
+
+ total_loss_dict[got_nan_key] = total_loss_dict.get(
+ got_nan_key, 0) + int(got_nan)
+
+ # Logging.
+ timers_to_log = []
+
+ def add_to_logging(name):
+ if name in timers.timers:
+ timers_to_log.append(name)
+ add_to_logging('forward')
+ add_to_logging('backward')
+ add_to_logging('backward-backward')
+ add_to_logging('backward-allreduce')
+ add_to_logging('backward-master-grad')
+ add_to_logging('backward-clip-grad')
+ add_to_logging('optimizer')
+ add_to_logging('batch generator')
+
+ # Tensorboard values.
+ if writer and torch.distributed.get_rank() == 0:
+ writer.add_scalar('tokens', args.tokens, iteration)
+ writer.add_scalar('learning_rate', learning_rate, iteration)
+ writer.add_scalar('learning_rate/vs tokens', learning_rate, args.tokens)
+ if args.curriculum_learning:
+ writer.add_scalar('seqlen',
+ args.curriculum_seqlen, iteration)
+ writer.add_scalar('seqlen/vs tokens',
+ args.curriculum_seqlen, args.tokens)
+ for key in loss_dict:
+ writer.add_scalar(key, loss_dict[key], iteration)
+ writer.add_scalar(key + '/vs tokens', loss_dict[key], args.tokens)
+ if args.fp16:
+ writer.add_scalar('loss_scale', loss_scale, iteration)
+ normalizer = iteration % args.log_interval
+ if normalizer == 0:
+ normalizer = args.log_interval
+ timers.write(timers_to_log, writer, iteration,
+ normalizer=normalizer)
+
+ if iteration % args.log_interval == 0:
+ elapsed_time = timers('interval time').elapsed()
+ if writer and torch.distributed.get_rank() == 0:
+ writer.add_scalar('iteration_time',
+ elapsed_time / args.log_interval, iteration)
+ log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
+ args.train_iters)
+ log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
+ elapsed_time * 1000.0 / args.log_interval)
+ log_string += ' learning rate: {:.3E} |'.format(learning_rate)
+ num_iterations = max(
+ 1, args.log_interval - total_loss_dict[skipped_iters_key])
+ for key in total_loss_dict:
+ if key not in [skipped_iters_key, got_nan_key]:
+ avg = total_loss_dict[key].item() / float(num_iterations)
+ log_string += ' {}: {:.6E} |'.format(key, avg)
+ total_loss_dict[key] = 0.0
+ if args.fp16:
+ log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+ log_string += ' number of skipped iterations: {:3d} |'.format(
+ total_loss_dict[skipped_iters_key])
+ log_string += ' number of nan iterations: {:3d} |'.format(
+ total_loss_dict[got_nan_key])
+ total_loss_dict[skipped_iters_key] = 0
+ total_loss_dict[got_nan_key] = 0
+ print_rank_0(log_string)
+ if report_memory_flag:
+ report_memory('after {} iterations'.format(iteration))
+ report_memory_flag = False
+ timers.log(timers_to_log, normalizer=args.log_interval)
+ flops_calculator(model, args, elapsed_time)
+
+ return report_memory_flag
+
+
+def train(forward_step_func, model, optimizer, lr_scheduler,
+ train_data_iterator, valid_data_iterator):
+ """Train the model function."""
+ args = get_args()
+ timers = get_timers()
+
+ # Turn on training mode which enables dropout.
+ model.train()
+
+ # Tracking loss.
+ total_loss_dict = {}
+
+ # Iterations.
+ iteration = args.iteration
+
+ timers('interval time').start()
+ report_memory_flag = True
+ data_parallel_size = mpu.get_data_parallel_world_size()
+ global_batch_size = args.batch_size * data_parallel_size
+ while iteration < args.train_iters and \
+ (args.train_tokens is None or args.tokens < args.train_tokens):
+ loss_dict, skipped_iter = train_step(forward_step_func,
+ train_data_iterator,
+ model,
+ optimizer,
+ lr_scheduler)
+ iteration += 1
+ if args.curriculum_learning:
+ args.tokens += global_batch_size * args.curriculum_seqlen
+ else:
+ args.tokens += global_batch_size * args.seq_length
+
+ # Logging.
+ loss_scale = None
+ if args.fp16:
+ loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale
+ report_memory_flag = training_log(loss_dict, total_loss_dict,
+ optimizer.param_groups[0]['lr'],
+ iteration, loss_scale,
+ report_memory_flag, skipped_iter,
+ model=model)
+
+ # Autoresume
+ if args.adlr_autoresume and \
+ (iteration % args.adlr_autoresume_interval == 0):
+ check_adlr_autoresume_termination(iteration, model, optimizer,
+ lr_scheduler)
+
+ # Checkpointing
+ if args.save and args.save_interval and \
+ iteration % args.save_interval == 0:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+ # Evaluation
+ # XXX temporarily disabled for ZeRO-3
+ """
+ if args.eval_interval and iteration % args.eval_interval == 0 and \
+ args.do_valid:
+ prefix = 'iteration {}'.format(iteration)
+ evaluate_and_print_results(prefix, forward_step_func,
+ valid_data_iterator, model,
+ iteration, False)
+ """
+
+ if args.exit_interval and iteration % args.exit_interval == 0:
+ torch.distributed.barrier()
+ time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ rank = torch.distributed.get_rank()
+ print_rank_0('rank: {} | time: {} | exiting the program at '
+ 'iteration {}'.format(rank, time_str, iteration))
+ sys.exit()
+
+ return iteration
+
+
+def evaluate(forward_step_func, data_iterator, model, verbose=False):
+ """Evaluation."""
+ args = get_args()
+
+ # Turn on evaluation mode which disables dropout.
+ model.eval()
+
+ total_loss_dict = {}
+
+ with torch.no_grad():
+ iteration = 0
+ while iteration < args.eval_iters:
+ iteration += 1
+ if verbose and iteration % args.log_interval == 0:
+ print_rank_0('Evaluating iter {}/{}'.format(iteration,
+ args.eval_iters))
+ # Forward evaluation.
+ _, loss_dict = forward_step_func(data_iterator, model)
+
+            # When contiguous memory optimizations are enabled, the buffers
+            # allocated by the optimizations are deallocated during the backward
+            # pass. In the absence of a backward pass, the buffers should be
+            # reset after each forward pass.
+ if args.deepspeed and args.deepspeed_activation_checkpointing:
+ deepspeed.checkpointing.reset()
+
+ # Reduce across processes.
+ for key in loss_dict:
+ total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
+ loss_dict[key]
+ # Move model back to the train mode.
+ model.train()
+
+ for key in total_loss_dict:
+ total_loss_dict[key] /= args.eval_iters
+
+ return total_loss_dict
+
+
+def evaluate_and_print_results(prefix, forward_step_func,
+ data_iterator, model,
+ iteration, verbose=False):
+ """Helper function to evaluate and dump results on screen."""
+ writer = get_tensorboard_writer()
+ args = get_args()
+
+ total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
+ string = ' validation loss at {} | '.format(prefix)
+ for key in total_loss_dict:
+ string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
+ ppl = math.exp(min(20, total_loss_dict[key].item()))
+ string += '{} PPL: {:.6E} | '.format(key, ppl)
+ if writer and torch.distributed.get_rank() == 0:
+ writer.add_scalar('{} value'.format(key),
+ total_loss_dict[key].item(),
+ iteration)
+ writer.add_scalar('{} value/vs tokens'.format(key),
+ total_loss_dict[key].item(),
+ args.tokens)
+ writer.add_scalar('{} ppl'.format(key), ppl, iteration)
+ writer.add_scalar('{} ppl/vs tokens'.format(key), ppl, args.tokens)
+
+ length = len(string) + 1
+ print_rank_0('-' * length)
+ print_rank_0(string)
+ print_rank_0('-' * length)
+
+
+def build_train_valid_test_data_iterators(
+ build_train_valid_test_datasets_provider):
+ """XXX"""
+ args = get_args()
+
+ (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
+
+ print_rank_0('> building train, validation, and test datasets ...')
+ # Data loader only on rank 0 of each model parallel group.
+ if mpu.get_model_parallel_rank() == 0:
+ # Rank, size, and global batch size.
+ data_parallel_size = mpu.get_data_parallel_world_size()
+ global_batch_size = args.batch_size * data_parallel_size
+
+ # Number of train/valid/test samples.
+ train_iters = args.train_iters
+ eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+ test_iters = args.eval_iters
+ train_val_test_num_samples = [train_iters * global_batch_size,
+ eval_iters * global_batch_size,
+ test_iters * global_batch_size]
+ print_rank_0(' > datasets target sizes (minimum size):')
+ print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
+ print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
+ print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
+
+ # Build the datasets.
+ train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+ train_val_test_num_samples)
+
+        # Build dataloaders.
+ train_dataloader = make_data_loader(train_ds)
+ valid_dataloader = make_data_loader(valid_ds)
+ test_dataloader = make_data_loader(test_ds)
+
+ # Flags to know if we need to do training/validation/testing.
+ do_train = train_dataloader is not None and args.train_iters > 0
+ do_valid = valid_dataloader is not None and args.eval_iters > 0
+ do_test = test_dataloader is not None and args.eval_iters > 0
+        # Pack the do_train/do_valid/do_test flags to broadcast to other ranks.
+ flags = torch.cuda.LongTensor(
+ [int(do_train), int(do_valid), int(do_test)])
+ else:
+ flags = torch.cuda.LongTensor([0, 0, 0])
+
+    # Broadcast the flags from the model-parallel source rank.
+ torch.distributed.broadcast(flags,
+ mpu.get_model_parallel_src_rank(),
+ group=mpu.get_model_parallel_group())
+ args.do_train = flags[0].item()
+ args.do_valid = flags[1].item()
+ args.do_test = flags[2].item()
+
+ # Shift the start iterations.
+ if train_dataloader is not None:
+ train_dataloader.batch_sampler.start_iter = args.iteration % \
+ len(train_dataloader)
+ print_rank_0('setting training data start iteration to {}'.
+ format(train_dataloader.batch_sampler.start_iter))
+ if valid_dataloader is not None:
+ start_iter_val = (args.iteration // args.eval_interval) * \
+ args.eval_iters
+ valid_dataloader.batch_sampler.start_iter = start_iter_val % \
+ len(valid_dataloader)
+ print_rank_0('setting validation data start iteration to {}'.
+ format(valid_dataloader.batch_sampler.start_iter))
+
+ # Build iterators.
+ if train_dataloader is not None:
+ train_data_iterator = iter(train_dataloader)
+ else:
+ train_data_iterator = None
+
+ if valid_dataloader is not None:
+ valid_data_iterator = iter(valid_dataloader)
+ else:
+ valid_data_iterator = None
+
+ if test_dataloader is not None:
+ test_data_iterator = iter(test_dataloader)
+ else:
+ test_data_iterator = None
+
+ return train_data_iterator, valid_data_iterator, test_data_iterator
diff --git a/megatron_lm/megatron/utils.py b/megatron_lm/megatron/utils.py
new file mode 100644
index 0000000..86fcf5e
--- /dev/null
+++ b/megatron_lm/megatron/utils.py
@@ -0,0 +1,196 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""General utilities for Megatron."""
+
+import sys
+
+import torch
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_adlr_autoresume
+from megatron import mpu
+from megatron.checkpointing import save_checkpoint
+from megatron.data.samplers import DistributedBatchSampler
+from megatron.fp16 import FP16_Optimizer
+
+
+def reduce_losses(losses):
+ """Reduce a tensor of losses across all GPUs."""
+ reduced_losses = torch.cat(
+ [loss.clone().detach().view(1) for loss in losses])
+ torch.distributed.all_reduce(reduced_losses)
+ reduced_losses = reduced_losses / torch.distributed.get_world_size()
+
+ return reduced_losses
+
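+# For reference (illustrative numbers): with 2 ranks reporting losses [2.0] and
+# [4.0], the all-reduce sums them to [6.0] and dividing by the world size yields
+# the averaged [3.0] on every rank.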
+
+def report_memory(name):
+ """Simple GPU memory report."""
+ mega_bytes = 1024.0 * 1024.0
+ string = name + ' memory (MB)'
+ string += ' | allocated: {}'.format(
+ torch.cuda.memory_allocated() / mega_bytes)
+ string += ' | max allocated: {}'.format(
+ torch.cuda.max_memory_allocated() / mega_bytes)
+ string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
+ string += ' | max reserved: {}'.format(
+ torch.cuda.max_memory_reserved() / mega_bytes)
+ print_rank_0(string)
+
+
+def print_params_min_max_norm(optimizer, iteration):
+ """Print min, max, and norm of all parameters."""
+ index = 0
+ rank = torch.distributed.get_rank()
+    string = 'iteration, rank, index, model-parallel, min, max, norm\n'
+ optimizer_ = optimizer
+ if isinstance(optimizer, FP16_Optimizer):
+ optimizer_ = optimizer.optimizer
+ for param_group in optimizer_.param_groups:
+ for param in param_group['params']:
+ index += 1
+ min_ = param.data.min()
+ max_ = param.data.max()
+ norm = param.data.norm()
+ string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
+ iteration, rank, index, int(param.model_parallel))
+ string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
+ print(string, flush=True)
+
+
+def check_adlr_autoresume_termination(iteration, model,
+ optimizer, lr_scheduler):
+ """Check for autoresume signal and exit if it is received."""
+ args = get_args()
+ autoresume = get_adlr_autoresume()
+    # Add barrier to ensure consistency.
+ torch.distributed.barrier()
+ if autoresume.termination_requested():
+ if args.save:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+ print_rank_0(">>> autoresume termination request found!")
+ if torch.distributed.get_rank() == 0:
+ autoresume.request_resume()
+ print_rank_0(">>> training terminated. Returning")
+ sys.exit(0)
+
+
+def make_data_loader(dataset):
+ """Buld dataloader given an input dataset."""
+ if dataset is None:
+ return None
+ args = get_args()
+
+ # Data parallel arguments.
+ world_size = mpu.get_data_parallel_world_size()
+ rank = mpu.get_data_parallel_rank()
+ global_batch_size = args.batch_size * world_size
+ num_workers = args.num_workers
+
+ # Use a simple sampler with distributed batch sampler.
+ sampler = torch.utils.data.SequentialSampler(dataset)
+ batch_sampler = DistributedBatchSampler(sampler=sampler,
+ batch_size=global_batch_size,
+ drop_last=True,
+ rank=rank,
+ world_size=world_size)
+ # Torch dataloader.
+ return torch.utils.data.DataLoader(dataset,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ pin_memory=True)
+
+
+def get_ltor_masks_and_position_ids(data,
+ eod_token,
+ reset_position_ids,
+ reset_attention_mask,
+ eod_mask_loss):
+ """Build masks and position id for left to right model."""
+
+ # Extract batch size and sequence length.
+ batch_size, seq_length = data.size()
+
+ # Attention mask (lower triangular).
+ if reset_attention_mask:
+ att_mask_batch = batch_size
+ else:
+ att_mask_batch = 1
+ attention_mask = torch.tril(torch.ones(
+ (att_mask_batch, seq_length, seq_length), device=data.device)).view(
+ att_mask_batch, 1, seq_length, seq_length)
+
+ # Loss mask.
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
+ if eod_mask_loss:
+ loss_mask[data == eod_token] = 0.0
+
+ # Position ids.
+ position_ids = torch.arange(seq_length, dtype=torch.long,
+ device=data.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modified based on the batch index.
+ if reset_position_ids:
+ position_ids = position_ids.clone()
+
+ if reset_position_ids or reset_attention_mask:
+ # Loop through the batches:
+ for b in range(batch_size):
+
+            # Find the indices of the EOD tokens.
+ eod_index = position_ids[b, data[b] == eod_token]
+            # Detach the indices from positions if we are going to modify positions.
+ if reset_position_ids:
+ eod_index = eod_index.clone()
+
+            # Loop through the EOD indices:
+ prev_index = 0
+ for j in range(eod_index.size()[0]):
+ i = eod_index[j]
+                # Mask out attention across the EOD boundary.
+ if reset_attention_mask:
+ attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+ # Reset positions.
+ if reset_position_ids:
+ position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+ prev_index = i + 1
+
+ # Convert attention mask to binary:
+ attention_mask = (attention_mask < 0.5)
+
+ return attention_mask, loss_mask, position_ids
+
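+# Worked example (illustrative): for a single sequence [5, 9, EOD, 7, 3] with
+# reset_position_ids=True the position ids become [0, 1, 2, 0, 1] (restarting
+# after the EOD token); with reset_attention_mask=True the tokens after the EOD
+# cannot attend to the EOD or anything before it; with eod_mask_loss=True the
+# loss mask is 0 at the EOD position and 1 elsewhere.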
+
+
+def get_parameters_in_billions(model):
+ gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())
+
+ approx_parameters_in_billions = sum([p.ds_numel if hasattr(p,'ds_id') else p.numel() for p in model.parameters()]) * gpus_per_model / 1000000000.0
+
+ return approx_parameters_in_billions
+
+
+def flops_calculator(model, args, iteration_time):
+ gpus_per_model = torch.distributed.get_world_size(group = mpu.get_model_parallel_group())
+
+ approx_parameters_in_billions = get_parameters_in_billions(model)
+
+ giga_flops_per_model_per_train_step = approx_parameters_in_billions * args.batch_size * args.seq_length * 2.0 * 4.0
+
+ effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (iteration_time * 1000.0 * gpus_per_model)
+
+ print_rank_0(f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B")
diff --git a/megatron_lm/pretrain_bert.py b/megatron_lm/pretrain_bert.py
new file mode 100644
index 0000000..b937b36
--- /dev/null
+++ b/megatron_lm/pretrain_bert.py
@@ -0,0 +1,123 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pretrain BERT"""
+
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_timers
+from megatron import mpu
+from megatron.data.dataset_utils import build_train_valid_test_datasets
+from megatron.model import BertModel
+from megatron.training import pretrain
+from megatron.utils import reduce_losses
+
+
+def model_provider():
+ """Build the model."""
+
+ print_rank_0('building BERT model ...')
+
+ model = BertModel(
+ num_tokentypes=2,
+ add_binary_head=True,
+ parallel_output=True)
+
+ return model
+
+
+def get_batch(data_iterator):
+ """Build the batch."""
+
+ # Items and their type.
+ keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
+ datatype = torch.int64
+
+ # Broadcast data.
+ if data_iterator is not None:
+ data = next(data_iterator)
+ else:
+ data = None
+ data_b = mpu.broadcast_data(keys, data, datatype)
+
+ # Unpack.
+ tokens = data_b['text'].long()
+ types = data_b['types'].long()
+ sentence_order = data_b['is_random'].long()
+ loss_mask = data_b['loss_mask'].float()
+ lm_labels = data_b['labels'].long()
+ padding_mask = data_b['padding_mask'].long()
+
+ return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
+
+
+def forward_step(data_iterator, model):
+ """Forward step."""
+ args = get_args()
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch generator').start()
+ tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
+ = get_batch(data_iterator)
+ timers('batch generator').stop()
+
+    # Forward model.
+ lm_loss_, sop_logits = model(tokens, padding_mask,
+ tokentype_ids=types,
+ lm_labels=lm_labels)
+
+ sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+ sentence_order.view(-1),
+ ignore_index=-1)
+
+ lm_loss = torch.sum(
+ lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+ loss = lm_loss + sop_loss
+
+ reduced_losses = reduce_losses([lm_loss, sop_loss])
+
+ return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
+
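+# For reference: the masked-LM loss above is averaged over the masked positions
+# only (division by loss_mask.sum()), and the training loss is the unweighted sum
+# of the masked-LM and sentence-order-prediction losses.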
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+ """Build train, valid, and test datasets."""
+ args = get_args()
+
+ print_rank_0('> building train, validation, and test datasets '
+ 'for BERT ...')
+ train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+ data_prefix=args.data_path,
+ data_impl=args.data_impl,
+ splits_string=args.split,
+ train_valid_test_num_samples=train_val_test_num_samples,
+ max_seq_length=args.seq_length,
+ masked_lm_prob=args.mask_prob,
+ short_seq_prob=args.short_seq_prob,
+ seed=args.seed,
+ skip_warmup=(not args.mmap_warmup))
+ print_rank_0("> finished creating BERT datasets ...")
+
+ return train_ds, valid_ds, test_ds
+
+
+if __name__ == "__main__":
+
+ pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
+ args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
diff --git a/megatron_lm/pretrain_gpt2.py b/megatron_lm/pretrain_gpt2.py
new file mode 100644
index 0000000..e12c4f2
--- /dev/null
+++ b/megatron_lm/pretrain_gpt2.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pretrain GPT2"""
+
+import torch
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_timers
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron.data.gpt2_dataset import build_train_valid_test_datasets
+from megatron.model import GPT2Model
+from megatron.training import pretrain
+from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import reduce_losses, get_parameters_in_billions
+
+
+import deepspeed
+from deepspeed.runtime.utils import see_memory_usage
+
+def model_provider(ds_init=True):
+ """Build the model."""
+
+ print_rank_0('building GPT2 model ...')
+ see_memory_usage(f"Before Building Model", force=True)
+ args = get_args()
+
+ if ds_init:
+ with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
+ remote_device=None if args.remote_device=='none' else args.remote_device,
+ config=args.deepspeed_config,
+ enabled=args.zero_stage==3):
+ model = GPT2Model(num_tokentypes=0, parallel_output=True)
+ else:
+ model = GPT2Model(num_tokentypes=0, parallel_output=True)
+
+ see_memory_usage(f"After Building Model", force=True)
+
+ if mpu.get_data_parallel_rank() == 0:
+ billion_params = get_parameters_in_billions(model)
+ print(f' > number of parameters on model parallel rank {mpu.get_model_parallel_rank()}\
+ {round(billion_params, 3)} Billion',
+ flush=True)
+
+ return model
+
+
+def get_batch(data_iterator):
+ """Generate a batch"""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ # Items and their type.
+ keys = ['text']
+ datatype = torch.int64
+
+ # Broadcast data.
+ if data_iterator is not None:
+ data = next(data_iterator)
+ else:
+ data = None
+ data_b = mpu.broadcast_data(keys, data, datatype)
+
+ # Unpack.
+ tokens_ = data_b['text'].long()
+ labels = tokens_[:, 1:].contiguous()
+ tokens = tokens_[:, :-1].contiguous()
+
+    # Get the masks and position ids.
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+ tokens,
+ tokenizer.eod,
+ args.reset_position_ids,
+ args.reset_attention_mask,
+ args.eod_mask_loss)
+
+ return tokens, labels, loss_mask, attention_mask, position_ids
+
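+# For reference: with a raw sample tokens_ = [t0, t1, t2, t3], the model input is
+# tokens = [t0, t1, t2] and the target is labels = [t1, t2, t3], i.e. the usual
+# one-token left shift for next-token prediction.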
+
+def forward_step(data_iterator, model):
+ """Forward step."""
+ args = get_args()
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch generator').start()
+ tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+ data_iterator)
+ timers('batch generator').stop()
+ # Forward model.
+ losses = model(tokens, position_ids, attention_mask, labels=labels)
+ if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
+ loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
+ loss_mask = loss_mask.view(-1)
+ loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+ # Reduce loss for logging.
+ reduced_loss = reduce_losses([loss])
+
+ return loss, {'lm loss': reduced_loss[0]}
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+ """Build train, valid, and test datasets."""
+ args = get_args()
+
+ print_rank_0('> building train, validation, and test datasets '
+ 'for GPT2 ...')
+ train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+ data_prefix=args.data_path,
+ data_impl=args.data_impl,
+ splits_string=args.split,
+ train_valid_test_num_samples=train_val_test_num_samples,
+ seq_length=args.seq_length,
+ seed=args.seed,
+ skip_warmup=(not args.mmap_warmup))
+ print_rank_0("> finished creating GPT2 datasets ...")
+
+ return train_ds, valid_ds, test_ds
+
+
+if __name__ == "__main__":
+
+ pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
diff --git a/megatron_lm/pretrain_ict.py b/megatron_lm/pretrain_ict.py
new file mode 100644
index 0000000..05d6a9c
--- /dev/null
+++ b/megatron_lm/pretrain_ict.py
@@ -0,0 +1,138 @@
+# coding=utf-8
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pretrain BERT for Inverse Cloze Task"""
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_timers
+from megatron import mpu
+from megatron.data.dataset_utils import build_train_valid_test_datasets
+from megatron.training import pretrain
+from megatron.utils import reduce_losses
+from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.realm_dataset_utils import get_ict_batch
+
+
+def pretrain_ict_model_provider():
+ return general_ict_model_provider(False, False)
+
+
+def get_group_world_size_rank():
+
+ group = mpu.get_data_parallel_group()
+ rank = torch.distributed.get_rank(group=group)
+ world_size = torch.distributed.get_world_size(group=group)
+
+ return group, rank, world_size
+
+
+class AllgatherFromDataParallelRegion(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input_):
+ assert input_.dim() == 2
+ group, rank, world_size = get_group_world_size_rank()
+
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ tensor_list[rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ output = torch.cat(tensor_list, dim=0).contiguous()
+
+ return output
+
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ group, rank, world_size = get_group_world_size_rank()
+
+ assert grad_output.shape[0] % world_size == 0
+ dim_size = grad_output.shape[0] // world_size
+ output_list = torch.split(grad_output, dim_size, dim=0)
+
+ # get chunk from this rank
+ output = output_list[rank].contiguous()
+ return output
+
+
+def forward_step(data_iterator, model):
+ """Forward step."""
+ args = get_args()
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch generator').start()
+ query_tokens, query_pad_mask, \
+ block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
+ timers('batch generator').stop()
+
+
+ # Forward model.
+ query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+ local_batch_size = query_logits.shape[0]
+ global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that model_parallel_size == 1
+
+ all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+ all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+
+ # scores are inner products between query and block embeddings
+ retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
+ softmaxed = F.softmax(retrieval_scores, dim=1)
+ sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
+
+ def topk_accuracy(k):
+ return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+
+ topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+ retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
+ reduced_losses = reduce_losses([retrieval_loss, *topk_accs])
+
+ # create stats_dict with retrieval loss and all specified top-k accuracies
+ topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
+ stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
+
+ return retrieval_loss, stats_dict
+
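+# For reference: after the all-gathers, retrieval_scores is an N x N matrix for a
+# global batch of N query/block pairs; the matching block for query i sits on the
+# diagonal, which is why the cross-entropy targets are torch.arange(N) and every
+# other block in the batch acts as an in-batch negative.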
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+ """Build train, valid and test datasets."""
+ args = get_args()
+ print_rank_0('> building train, validation, and test datasets '
+ 'for BERT ICT...')
+
+ train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+ data_prefix=args.data_path,
+ data_impl=args.data_impl,
+ splits_string=args.split,
+ train_valid_test_num_samples=train_val_test_num_samples,
+ max_seq_length=args.seq_length,
+ masked_lm_prob=args.mask_prob,
+ short_seq_prob=args.short_seq_prob,
+ seed=args.seed,
+ skip_warmup=(not args.mmap_warmup),
+ dataset_type='ict')
+ print_rank_0("> finished creating BERT ICT datasets ...")
+
+ return train_ds, valid_ds, test_ds
+
+
+if __name__ == "__main__":
+ pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
+ args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
diff --git a/megatron_lm/requirements.txt b/megatron_lm/requirements.txt
new file mode 100644
index 0000000..1f7389c
--- /dev/null
+++ b/megatron_lm/requirements.txt
@@ -0,0 +1,5 @@
+pybind11
+torch
+six
+regex
+numpy
diff --git a/megatron_lm/setup.py b/megatron_lm/setup.py
new file mode 100644
index 0000000..e074781
--- /dev/null
+++ b/megatron_lm/setup.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Setup for pip package for Megatron."""
+
+import os
+import sys
+import setuptools
+
+if sys.version_info < (3,):
+ raise Exception("Python 2 is not supported by Megatron.")
+
+from megatron.package_info import (
+ __description__,
+ __contact_names__,
+ __url__,
+ __download_url__,
+ __keywords__,
+ __license__,
+ __package_name__,
+ __version__,
+)
+
+with open("README.md", "r") as fh:
+ long_description = fh.read()
+
+###############################################################################
+# Dependency Loading #
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
+
+
+def req_file(filename):
+ with open(filename) as f:
+ content = f.readlines()
+ return [x.strip() for x in content]
+
+
+install_requires = req_file("requirements.txt")
+
+setuptools.setup(
+ name=__package_name__,
+ # Versions should comply with PEP440. For a discussion on single-sourcing
+ # the version across setup.py and the project code, see
+ # https://packaging.python.org/en/latest/single_source_version.html
+ version=__version__,
+ description=__description__,
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ # The project's main homepage.
+ url=__url__,
+ author=__contact_names__,
+ maintainer=__contact_names__,
+ # The licence under which the project is released
+ license=__license__,
+ classifiers=[
+ 'Intended Audience :: Developers',
+ 'Intended Audience :: Science/Research',
+ 'Intended Audience :: Information Technology',
+ # Indicate what your project relates to
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+ # Supported python versions
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ # Additional Setting
+ 'Environment :: Console',
+ 'Natural Language :: English',
+ 'Operating System :: OS Independent',
+ ],
+ python_requires='>=3.6',
+ packages=setuptools.find_packages(),
+ install_requires=install_requires,
+ # Add in any packaged data.
+ include_package_data=True,
+ zip_safe=False,
+ # PyPI package information.
+ keywords=__keywords__
+)
diff --git a/megatron_lm/tasks/data_utils.py b/megatron_lm/tasks/data_utils.py
new file mode 100644
index 0000000..866a5e6
--- /dev/null
+++ b/megatron_lm/tasks/data_utils.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Tasks data utility."""
+
+import re
+import numpy as np
+
+
+def clean_text(text):
+ """Remove new lines and multiple spaces and adjust end of sentence dot."""
+
+ text = text.replace("\n", " ")
+ text = re.sub(r'\s+', ' ', text)
+ for _ in range(3):
+ text = text.replace(' . ', '. ')
+
+ return text
+
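+# Example (illustrative): clean_text("One .\nTwo  words . End") returns
+# "One. Two words. End": newlines become spaces, runs of whitespace collapse,
+# and ' . ' is re-attached to the preceding word.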
+
+def build_sample(ids, types, paddings, label, unique_id):
+ """Convert to numpy and return a sample consumed by the batch producer."""
+
+ ids_np = np.array(ids, dtype=np.int64)
+ types_np = np.array(types, dtype=np.int64)
+ paddings_np = np.array(paddings, dtype=np.int64)
+ sample = ({'text': ids_np,
+ 'types': types_np,
+ 'padding_mask': paddings_np,
+ 'label': int(label),
+ 'uid': int(unique_id)})
+
+ return sample
+
+
+def build_tokens_types_paddings_from_text(text_a, text_b,
+ tokenizer, max_seq_length):
+ """Build token types and paddings, trim if needed, and pad if needed."""
+
+ text_a_ids = tokenizer.tokenize(text_a)
+ text_b_ids = None
+ if text_b is not None:
+ text_b_ids = tokenizer.tokenize(text_b)
+
+ return build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids,
+ max_seq_length, tokenizer.cls,
+ tokenizer.sep, tokenizer.pad)
+
+
+def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
+ cls_id, sep_id, pad_id):
+ """Build token types and paddings, trim if needed, and pad if needed."""
+
+ ids = []
+ types = []
+ paddings = []
+
+ # [CLS].
+ ids.append(cls_id)
+ types.append(0)
+ paddings.append(1)
+
+ # A.
+ len_text_a = len(text_a_ids)
+ ids.extend(text_a_ids)
+ types.extend([0] * len_text_a)
+ paddings.extend([1] * len_text_a)
+
+ # [SEP].
+ ids.append(sep_id)
+ types.append(0)
+ paddings.append(1)
+
+ # B.
+ if text_b_ids is not None:
+ len_text_b = len(text_b_ids)
+ ids.extend(text_b_ids)
+ types.extend([1] * len_text_b)
+ paddings.extend([1] * len_text_b)
+
+ # Cap the size.
+ trimmed = False
+ if len(ids) >= max_seq_length:
+ max_seq_length_m1 = max_seq_length - 1
+ ids = ids[0:max_seq_length_m1]
+ types = types[0:max_seq_length_m1]
+ paddings = paddings[0:max_seq_length_m1]
+ trimmed = True
+
+ # [SEP].
+ if (text_b_ids is not None) or trimmed:
+ ids.append(sep_id)
+ if text_b_ids is None:
+ types.append(0)
+ else:
+ types.append(1)
+ paddings.append(1)
+
+ # Padding.
+ padding_length = max_seq_length - len(ids)
+ if padding_length > 0:
+ ids.extend([pad_id] * padding_length)
+ types.extend([pad_id] * padding_length)
+ paddings.extend([0] * padding_length)
+
+ return ids, types, paddings
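+
+
+# Worked example (illustrative, using the standard BERT ids [CLS]=101, [SEP]=102,
+# [PAD]=0): text_a_ids=[7, 8], text_b_ids=[9], max_seq_length=8 gives
+#     ids      = [101, 7, 8, 102, 9, 102, 0, 0]
+#     types    = [  0, 0, 0,   0, 1,   1, 0, 0]
+#     paddings = [  1, 1, 1,   1, 1,   1, 0, 0]
+# (note that the padded tail of `types` is filled with pad_id as well).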
diff --git a/megatron_lm/tasks/ensemble_classifier.py b/megatron_lm/tasks/ensemble_classifier.py
new file mode 100644
index 0000000..c2333b7
--- /dev/null
+++ b/megatron_lm/tasks/ensemble_classifier.py
@@ -0,0 +1,149 @@
+import os
+import argparse
+import collections
+
+import numpy as np
+import torch
+
+
+def process_files(args):
+ all_predictions = collections.OrderedDict()
+ all_labels = collections.OrderedDict()
+ all_uid = collections.OrderedDict()
+ for path in args.paths:
+ path = os.path.join(path, args.prediction_name)
+ try:
+ data = torch.load(path)
+ for dataset in data:
+ name, d = dataset
+ predictions, labels, uid = d
+ if name not in all_predictions:
+ all_predictions[name] = np.array(predictions)
+ if args.labels is None:
+ args.labels = [i for i in range(all_predictions[name].shape[1])]
+ if args.eval:
+ all_labels[name] = np.array(labels)
+ all_uid[name] = np.array(uid)
+ else:
+ all_predictions[name] += np.array(predictions)
+ assert np.allclose(all_uid[name], np.array(uid))
+ except Exception as e:
+ print(e)
+ continue
+ return all_predictions, all_labels, all_uid
+
+
+def get_threshold(all_predictions, all_labels, one_threshold=False):
+ if one_threshold:
+        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
+        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
+ out_thresh = []
+ for dataset in all_predictions:
+ preds = all_predictions[dataset]
+ labels = all_labels[dataset]
+ out_thresh.append(calc_threshold(preds, labels))
+ return out_thresh
+
+
+def calc_threshold(p, l):
+ trials = [(i) * (1. / 100.) for i in range(100)]
+ best_acc = float('-inf')
+ best_thresh = 0
+ for t in trials:
+ acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
+ if acc > best_acc:
+ best_acc = acc
+ best_thresh = t
+ return best_thresh
+
+
+def apply_threshold(preds, t):
+ assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
+ prob = preds[:, -1]
+ thresholded = (prob >= t).astype(int)
+ preds = np.zeros_like(preds)
+ preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
+ return preds
+
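+# Example (illustrative): for a softmaxed row [0.3, 0.7] and t=0.6, the positive
+# probability 0.7 clears the threshold, so the row is re-encoded as the one-hot
+# prediction [0, 1]; with t=0.8 it would become [1, 0] instead.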
+
+def threshold_predictions(all_predictions, threshold):
+ if len(threshold) != len(all_predictions):
+ threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
+ for i, dataset in enumerate(all_predictions):
+ thresh = threshold[i]
+ preds = all_predictions[dataset]
+ all_predictions[dataset] = apply_threshold(preds, thresh)
+ return all_predictions
+
+
+def postprocess_predictions(all_predictions, all_labels, args):
+ for d in all_predictions:
+ all_predictions[d] = all_predictions[d] / len(args.paths)
+
+ if args.calc_threshold:
+ args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
+ print('threshold', args.threshold)
+
+ if args.threshold is not None:
+ all_predictions = threshold_predictions(all_predictions, args.threshold)
+
+ return all_predictions, all_labels
+
+
+def write_predictions(all_predictions, all_labels, all_uid, args):
+ all_correct = 0
+ count = 0
+ for dataset in all_predictions:
+ preds = all_predictions[dataset]
+ preds = np.argmax(preds, -1)
+ if args.eval:
+ correct = (preds == all_labels[dataset]).sum()
+ num = len(all_labels[dataset])
+ accuracy = correct / num
+ count += num
+ all_correct += correct
+ accuracy = (preds == all_labels[dataset]).mean()
+ print(accuracy)
+ if not os.path.exists(os.path.join(args.outdir, dataset)):
+ os.makedirs(os.path.join(args.outdir, dataset))
+ outpath = os.path.join(
+ args.outdir, dataset, os.path.splitext(
+ args.prediction_name)[0] + '.tsv')
+ with open(outpath, 'w') as f:
+ f.write('id\tlabel\n')
+ f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
+ for uid, p in zip(all_uid[dataset], preds.tolist())))
+ if args.eval:
+ print(all_correct / count)
+
+
+def ensemble_predictions(args):
+ all_predictions, all_labels, all_uid = process_files(args)
+ all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
+ write_predictions(all_predictions, all_labels, all_uid, args)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--paths', required=True, nargs='+',
+ help='paths to checkpoint directories used in ensemble')
+ parser.add_argument('--eval', action='store_true',
+ help='compute accuracy metrics against labels (dev set)')
+ parser.add_argument('--outdir',
+ help='directory to place ensembled predictions in')
+ parser.add_argument('--prediction-name', default='test_predictions.pt',
+ help='name of predictions in checkpoint directories')
+ parser.add_argument('--calc-threshold', action='store_true',
+ help='calculate the classification threshold from predictions and labels')
+ parser.add_argument('--one-threshold', action='store_true',
+ help='use one threshold for all subdatasets')
+ parser.add_argument('--threshold', nargs='+', default=None, type=float,
+ help='user supplied threshold for classification')
+ parser.add_argument('--labels', nargs='+', default=None,
+ help='whitespace separated list of label names')
+ args = parser.parse_args()
+ ensemble_predictions(args)
+
+
+if __name__ == '__main__':
+ main()
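The threshold search above sweeps 100 candidate cut-offs and keeps the one with the best argmax accuracy. A minimal standalone sketch of that logic on a toy binary-prediction array (it re-implements apply_threshold locally rather than importing the script, so nothing from the repo needs to be on PYTHONPATH):

import numpy as np

def apply_threshold(preds, t):
    # preds holds rows of class probabilities that sum to 1; a row is
    # assigned the positive class iff its last-column probability >= t.
    assert np.allclose(preds.sum(-1), np.ones(preds.shape[0]))
    chosen = (preds[:, -1] >= t).astype(int)
    one_hot = np.zeros_like(preds)
    one_hot[np.arange(len(chosen)), chosen] = 1
    return one_hot

preds = np.array([[0.70, 0.30],
                  [0.40, 0.60],
                  [0.55, 0.45]])
labels = np.array([0, 1, 1])

# Sweep thresholds in steps of 0.01, as calc_threshold does above.
best_t = max((t * 0.01 for t in range(100)),
             key=lambda t: (apply_threshold(preds, t).argmax(-1) == labels).mean())
print('best threshold:', best_t)
print(apply_threshold(preds, best_t))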
diff --git a/megatron_lm/tasks/eval_utils.py b/megatron_lm/tasks/eval_utils.py
new file mode 100644
index 0000000..c89ea2c
--- /dev/null
+++ b/megatron_lm/tasks/eval_utils.py
@@ -0,0 +1,127 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation utilities."""
+
+import os
+import time
+
+import torch
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import mpu
+from tasks.finetune_utils import build_data_loader
+from tasks.finetune_utils import process_batch
+
+
+def accuracy_func_provider(single_dataset_provider):
+ """Provide function that calculates accuracies."""
+ args = get_args()
+
+ # Build dataloaders.
+ datapaths = args.valid_data
+ dataloaders = []
+ for datapath in datapaths:
+ dataset = single_dataset_provider(datapath)
+ dataloader = build_data_loader(
+ dataset, args.batch_size, num_workers=args.num_workers,
+ drop_last=(mpu.get_data_parallel_world_size() > 1))
+ dataloaders.append((dataset.dataset_name, dataloader))
+
+ def metrics_func(model, epoch, output_predictions=False):
+ print_rank_0('calculating metrics ...')
+ correct = 0
+ total = 0
+ if output_predictions:
+ assert mpu.get_data_parallel_world_size() == 1
+ named_predictions = []
+ names = 'predictions'
+ for name, dataloader in dataloaders:
+ output = calculate_correct_answers(name, model, dataloader,
+ epoch, output_predictions)
+ if not output_predictions:
+ correct_ans, total_count = output
+ else:
+ correct_ans, total_count, predictions = output
+ named_predictions.append((name, predictions))
+ names += '_' + name
+ correct += correct_ans
+ total += total_count
+ percent = float(correct) * 100.0 / float(total)
+ print_rank_0(' >> |epoch: {}| overall: correct / total = {} / {} = '
+ '{:.4f} %'.format(epoch, correct, total, percent))
+
+ if output_predictions and torch.distributed.get_rank() == 0:
+ assert args.load is not None
+ filename = os.path.join(args.load, names + '.pt')
+ torch.save(named_predictions, filename)
+
+ return metrics_func
+
+
+def calculate_correct_answers(name, model, dataloader,
+ epoch, output_predictions):
+ """Calculate correct over total answers and return prediction if the
+ `output_predictions` is true."""
+
+ start_time = time.time()
+ model.eval()
+ with torch.no_grad():
+ # For all the batches in the dataset.
+ total = 0
+ correct = 0
+ if output_predictions:
+ # This option is only possible when data parallel size is 1.
+ assert mpu.get_data_parallel_world_size() == 1
+ softmaxes = []
+ labels = []
+ ids = []
+ for _, batch in enumerate(dataloader):
+ # Run the model forward.
+ tokens, types, labels_, attention_mask = process_batch(batch)
+ logits = model(tokens, attention_mask, types)
+ # Add output predictions.
+ if output_predictions:
+ softmaxes.extend(torch.nn.Softmax(dim=-1)(
+ logits.float()).data.cpu().numpy().tolist())
+ labels.extend(labels_.data.cpu().numpy().tolist())
+ ids.extend(batch['uid'].cpu().numpy().tolist())
+ # Compute the correct answers.
+ predicted = torch.argmax(logits, dim=-1)
+ corrects = (predicted == labels_)
+ # Add to the counters.
+ total += labels_.size(0)
+ correct += corrects.sum().item()
+ model.train()
+
+ # Reduce.
+ unreduced = torch.cuda.LongTensor([correct, total])
+ torch.distributed.all_reduce(unreduced,
+ group=mpu.get_data_parallel_group())
+
+ # Print on screen.
+ correct_ans = unreduced[0].item()
+ total_count = unreduced[1].item()
+ percent = float(correct_ans) * 100.0 / float(total_count)
+ elapsed_time = time.time() - start_time
+ print_rank_0(' > |epoch: {}| metrics for {}: correct / total '
+ '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
+ epoch, name, correct_ans, total_count,
+ percent, elapsed_time))
+
+ if output_predictions:
+ return correct_ans, total_count, (softmaxes, labels, ids)
+ return correct_ans, total_count
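For reference, the (softmaxes, labels, ids) triples returned here are what torch.save writes per dataset, and what the ensemble script's process_files above later reloads. A small sketch with made-up values showing that on-disk structure (the file name is hypothetical):

import torch

named_predictions = [
    ('dev-matched',
     ([[0.2, 0.5, 0.3], [0.6, 0.1, 0.3]],   # per-example softmax outputs
      [1, 0],                               # gold labels
      [101, 102])),                         # example uids
]
torch.save(named_predictions, 'predictions_dev-matched.pt')

for name, (softmaxes, labels, uids) in torch.load('predictions_dev-matched.pt'):
    print(name, len(softmaxes), 'examples')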
diff --git a/megatron_lm/tasks/finetune_utils.py b/megatron_lm/tasks/finetune_utils.py
new file mode 100644
index 0000000..fc813f4
--- /dev/null
+++ b/megatron_lm/tasks/finetune_utils.py
@@ -0,0 +1,259 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Finetune utilities."""
+
+import torch
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_timers
+from megatron import mpu
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
+from megatron.training import evaluate_and_print_results
+from megatron.training import setup_model_and_optimizer
+from megatron.training import train_step
+from megatron.training import training_log
+from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import reduce_losses
+
+
+def process_batch(batch):
+ """Process batch and produce inputs for the model."""
+ args = get_args()
+
+ tokens = batch['text'].long().cuda().contiguous()
+ types = batch['types'].long().cuda().contiguous()
+ labels = batch['label'].long().cuda().contiguous()
+ attention_mask = batch['padding_mask'].float().cuda().contiguous()
+ if args.fp16:
+ attention_mask = attention_mask.half()
+
+ return tokens, types, labels, attention_mask
+
+
+def _cross_entropy_forward_step(batch, model):
+ """Simple forward step with cross-entropy loss."""
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch generator').start()
+ try:
+ batch_ = next(batch)
+ except BaseException:
+ batch_ = batch
+ tokens, types, labels, attention_mask = process_batch(batch_)
+ timers('batch generator').stop()
+
+ # Forward model.
+ logits = model(tokens, attention_mask, types)
+
+ # Cross-entropy loss.
+ loss_func = torch.nn.CrossEntropyLoss()
+ loss = loss_func(logits.contiguous().float(), labels)
+
+ # Reduce loss for logging.
+ reduced_loss = reduce_losses([loss])
+
+ return loss, {'lm loss': reduced_loss[0]}
+
+
+def build_data_loader(dataset, batch_size, num_workers, drop_last):
+ """Data loader. Note that batch-size is the local (per GPU) batch-size."""
+
+ # Sampler.
+ world_size = mpu.get_data_parallel_world_size()
+ rank = mpu.get_data_parallel_rank()
+ sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset, num_replicas=world_size, rank=rank)
+
+ # Data loader. Note that batch size is the per GPU batch size.
+ data_loader = torch.utils.data.DataLoader(dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ shuffle=False,
+ num_workers=num_workers,
+ drop_last=drop_last,
+ pin_memory=True)
+
+ return data_loader
+
+
+def _build_infinite_size_dataloader(dataloader):
+ """Build a looped dataloader with infinite size."""
+
+ iterator = dataloader.__iter__()
+ while True:
+ try:
+ yield iterator.__next__()
+ except StopIteration:
+ iterator = dataloader.__iter__()
+
+
+def _build_train_valid_dataloaders(train_dataset, valid_dataset):
+ """Traing and validation dataloaders."""
+ args = get_args()
+
+ print_rank_0('building train and validation dataloaders ...')
+ # Training dataset.
+ train_dataloader = build_data_loader(train_dataset, args.batch_size,
+ args.num_workers, not args.keep_last)
+ # Set the training iterations.
+ args.train_iters_per_epoch = len(train_dataloader)
+ args.train_iters = args.epochs * args.train_iters_per_epoch
+ # Validation dataset. For this dataset, we do not need to set up
+ # shuffling so we can just use a simple infinite loop.
+ valid_dataloader_ = build_data_loader(valid_dataset, args.batch_size,
+ args.num_workers, not args.keep_last)
+ valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
+
+ return train_dataloader, valid_dataloader
+
+
+def _train(model, optimizer, lr_scheduler, forward_step,
+ train_dataloader, valid_dataloader, end_of_epoch_callback):
+ """Train the model."""
+ args = get_args()
+ timers = get_timers()
+
+ # Turn on training mode which enables dropout.
+ model.train()
+
+ # Tracking loss.
+ losses_dict_sum = {}
+
+ # Starting epoch and iteration
+ start_epoch = args.iteration // args.train_iters_per_epoch
+ start_iteration = args.iteration % args.train_iters_per_epoch
+ iteration = args.iteration
+
+ # Memory reporting flag.
+ report_memory_flag = True
+
+ # For each remaining epoch
+ timers('interval time').start()
+ for epoch in range(start_epoch, args.epochs):
+ print_rank_0('working on epoch {} ...'.format(epoch + 1))
+
+ # Set the data loader epoch to shuffle the index iterator.
+ train_dataloader.sampler.set_epoch(args.seed + epoch)
+
+ # For all the batches in the dataset.
+ for iteration_, batch in enumerate(train_dataloader):
+
+ # Ignore the iterations before starting value
+ if iteration_ < start_iteration:
+ continue
+ # Set to zero so the next epoch does not skip any batches.
+ start_iteration = 0
+
+ # Train for one step.
+ losses_dict, _ = train_step(forward_step, batch, model,
+ optimizer, lr_scheduler)
+ iteration += 1
+
+ # Logging.
+ report_memory_flag = training_log(losses_dict, losses_dict_sum,
+ optimizer.param_groups[0]['lr'],
+ iteration, optimizer.loss_scale,
+ report_memory_flag)
+
+ # Autoresume
+ if args.adlr_autoresume and \
+ (iteration % args.adlr_autoresume_interval == 0):
+ check_adlr_autoresume_termination(iteration, model,
+ optimizer, lr_scheduler)
+
+ # Checkpointing
+ if args.save and args.save_interval and \
+ iteration % args.save_interval == 0:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+ # Evaluation
+ if args.eval_interval and iteration % args.eval_interval == 0:
+ prefix = 'iteration {}'.format(iteration)
+ evaluate_and_print_results(prefix, forward_step,
+ valid_dataloader, model,
+ iteration, False)
+
+ # Checkpointing at the end of each epoch.
+ if args.save:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+ # Callback at the end of each epoch.
+ if end_of_epoch_callback is not None:
+ end_of_epoch_callback(model, epoch)
+
+
+def finetune(train_valid_datasets_provider, model_provider,
+ forward_step=_cross_entropy_forward_step,
+ end_of_epoch_callback_provider=None):
+ """Main finetune function used across all tasks."""
+ args = get_args()
+ timers = get_timers()
+
+ # Train and validation data loaders.
+ timers('train/valid/test dataset/dataloader').start()
+ if args.epochs > 0:
+ train_dataset, valid_dataset = train_valid_datasets_provider()
+ train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
+ train_dataset, valid_dataset)
+ timers('train/valid/test dataset/dataloader').stop()
+
+ # Build callback function.
+ timers('callback function').start()
+ end_of_epoch_callback = None
+ if end_of_epoch_callback_provider is not None:
+ end_of_epoch_callback = end_of_epoch_callback_provider()
+ timers('callback function').stop()
+
+ # Build model, optimizer and learning rate scheduler.
+ timers('model and optimizer').start()
+ model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+ timers('model and optimizer').stop()
+
+ # If pretrained checkpoint is provided and we have not trained for
+ # any iteration (i.e., iteration is zero), then load the pretrained
+ # checkpoint.
+ timers('pretrained checkpoint').start()
+ if args.iteration == 0 and args.pretrained_checkpoint is not None:
+ original_load = args.load
+ args.load = args.pretrained_checkpoint
+ _ = load_checkpoint(model, None, None)
+ args.load = original_load
+ # This is critical when only model is loaded. We should make sure
+ # master parameters are also updated.
+ if args.fp16:
+ optimizer._model_params_to_master_params()
+ timers('pretrained checkpoint').stop()
+
+ # Print setup timing.
+ print_rank_0('done with setups ...')
+ timers.log(['train/valid/test dataset/dataloader', 'callback function',
+ 'model and optimizer', 'pretrained checkpoint'])
+ print_rank_0('training ...')
+
+ # Finetune the model.
+ if args.epochs > 0:
+ _train(model, optimizer, lr_scheduler, forward_step,
+ train_dataloader, valid_dataloader, end_of_epoch_callback)
+ # Or just evaluate.
+ else:
+ if end_of_epoch_callback is not None:
+ print_rank_0('evaluation only mode, setting epoch to -1')
+ end_of_epoch_callback(model, epoch=-1, output_predictions=True)
+
+ print_rank_0('done :-)')
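As a side note on _build_infinite_size_dataloader above: it simply restarts iteration whenever the wrapped loader is exhausted. A standalone sketch of the same looping pattern with a plain PyTorch DataLoader and toy data (no distributed sampler):

import torch
from torch.utils.data import DataLoader, TensorDataset

def infinite(loader):
    # Restart the underlying iterator whenever it runs out,
    # mirroring _build_infinite_size_dataloader above.
    while True:
        for batch in loader:
            yield batch

loader = DataLoader(TensorDataset(torch.arange(10)), batch_size=4)
stream = infinite(loader)
for _ in range(5):   # more draws than one pass over the data
    print(next(stream))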
diff --git a/megatron_lm/tasks/glue/data.py b/megatron_lm/tasks/glue/data.py
new file mode 100644
index 0000000..357ad13
--- /dev/null
+++ b/megatron_lm/tasks/glue/data.py
@@ -0,0 +1,69 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GLUE dataset."""
+
+from abc import ABC
+from abc import abstractmethod
+
+from torch.utils.data import Dataset
+
+from megatron import print_rank_0
+from tasks.data_utils import build_sample
+from tasks.data_utils import build_tokens_types_paddings_from_text
+
+
+class GLUEAbstractDataset(ABC, Dataset):
+ """GLUE base dataset class."""
+
+ def __init__(self, task_name, dataset_name, datapaths,
+ tokenizer, max_seq_length):
+ # Store inputs.
+ self.task_name = task_name
+ self.dataset_name = dataset_name
+ self.tokenizer = tokenizer
+ self.max_seq_length = max_seq_length
+ print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
+ self.dataset_name))
+ # Process the files.
+ string = ' > paths:'
+ for path in datapaths:
+ string += ' ' + path
+ print_rank_0(string)
+ self.samples = []
+ for datapath in datapaths:
+ self.samples.extend(self.process_samples_from_single_path(datapath))
+ print_rank_0(' >> total number of samples: {}'.format(
+ len(self.samples)))
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, idx):
+ raw_sample = self.samples[idx]
+ ids, types, paddings = build_tokens_types_paddings_from_text(
+ raw_sample['text_a'], raw_sample['text_b'],
+ self.tokenizer, self.max_seq_length)
+ sample = build_sample(ids, types, paddings,
+ raw_sample['label'], raw_sample['uid'])
+ return sample
+
+ @abstractmethod
+ def process_samples_from_single_path(self, datapath):
+ """Abstract method that takes a single path / filename and
+ returns a list of dataset samples, each sample being a dict of
+ {'text_a': string, 'text_b': string, 'label': int, 'uid': int}
+ """
+ pass
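A concrete, purely hypothetical example of what process_samples_from_single_path is expected to produce: a parser for a tab-separated file with uid, sentence-pair and label columns, returning dicts in the schema documented above. The column layout here is made up; the MNLI and QQP datasets below show the real ones.

import csv

def process_samples_from_single_path(datapath):
    # Hypothetical layout: uid \t text_a \t text_b \t label
    samples = []
    with open(datapath, newline='') as f:
        for uid, text_a, text_b, label in csv.reader(f, delimiter='\t'):
            samples.append({'text_a': text_a,
                            'text_b': text_b,
                            'label': int(label),
                            'uid': int(uid)})
    return samples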
diff --git a/megatron_lm/tasks/glue/finetune.py b/megatron_lm/tasks/glue/finetune.py
new file mode 100644
index 0000000..631f7ef
--- /dev/null
+++ b/megatron_lm/tasks/glue/finetune.py
@@ -0,0 +1,90 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GLUE finetuning/evaluation."""
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_tokenizer
+from megatron.model.classification import Classification
+from tasks.eval_utils import accuracy_func_provider
+from tasks.finetune_utils import finetune
+
+
+def glue_classification(num_classes, Dataset,
+ name_from_datapath_func):
+
+ def train_valid_datasets_provider():
+ """Build train and validation dataset."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ train_dataset = Dataset('training', args.train_data,
+ tokenizer, args.seq_length)
+ valid_dataset = Dataset('validation', args.valid_data,
+ tokenizer, args.seq_length)
+
+ return train_dataset, valid_dataset
+
+ def model_provider():
+ """Build the model."""
+ args = get_args()
+
+ print_rank_0('building classification model for {} ...'.format(
+ args.task))
+
+ return Classification(num_classes=num_classes, num_tokentypes=2)
+
+ def metrics_func_provider():
+ """Privde metrics callback function."""
+ def single_dataset_provider(datapath):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ name = name_from_datapath_func(datapath)
+ return Dataset(name, [datapath], tokenizer, args.seq_length)
+ return accuracy_func_provider(single_dataset_provider)
+
+ """Finetune/evaluate."""
+ finetune(train_valid_datasets_provider, model_provider,
+ end_of_epoch_callback_provider=metrics_func_provider)
+
+
+def main():
+ args = get_args()
+
+ if args.task == 'MNLI':
+
+ num_classes = 3
+ from tasks.glue.mnli import MNLIDataset as Dataset
+
+ def name_from_datapath(datapath):
+ return datapath.split('MNLI')[-1].strip(
+ '.tsv').strip('/').replace('_', '-')
+
+ elif args.task == 'QQP':
+
+ num_classes = 2
+ from tasks.glue.qqp import QQPDataset as Dataset
+
+ def name_from_datapath(datapath):
+ return datapath.split('QQP')[-1].strip(
+ '.tsv').strip('/').replace('_', '-')
+
+ else:
+ raise NotImplementedError('GLUE task {} is not implemented.'.format(
+ args.task))
+
+ glue_classification(num_classes, Dataset, name_from_datapath)
diff --git a/megatron_lm/tasks/glue/mnli.py b/megatron_lm/tasks/glue/mnli.py
new file mode 100644
index 0000000..547a2a0
--- /dev/null
+++ b/megatron_lm/tasks/glue/mnli.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MNLI dataset."""
+
+from megatron import print_rank_0
+from tasks.data_utils import clean_text
+from .data import GLUEAbstractDataset
+
+
+LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
+
+
+class MNLIDataset(GLUEAbstractDataset):
+
+ def __init__(self, name, datapaths, tokenizer, max_seq_length,
+ test_label='contradiction'):
+ self.test_label = test_label
+ super().__init__('MNLI', name, datapaths,
+ tokenizer, max_seq_length)
+
+ def process_samples_from_single_path(self, filename):
+ """"Implement abstract method."""
+ print_rank_0(' > Processing {} ...'.format(filename))
+
+ samples = []
+ total = 0
+ first = True
+ is_test = False
+ with open(filename, 'r') as f:
+ for line in f:
+ row = line.strip().split('\t')
+ if first:
+ first = False
+ if len(row) == 10:
+ is_test = True
+ print_rank_0(
+ ' reading {}, {} and {} columns and setting '
+ 'labels to {}'.format(
+ row[0].strip(), row[8].strip(),
+ row[9].strip(), self.test_label))
+ else:
+ print_rank_0(' reading {} , {}, {}, and {} columns '
+ '...'.format(
+ row[0].strip(), row[8].strip(),
+ row[9].strip(), row[-1].strip()))
+ continue
+
+ text_a = clean_text(row[8].strip())
+ text_b = clean_text(row[9].strip())
+ unique_id = int(row[0].strip())
+ label = row[-1].strip()
+ if is_test:
+ label = self.test_label
+
+ assert len(text_a) > 0
+ assert len(text_b) > 0
+ assert label in LABELS
+ assert unique_id >= 0
+
+ sample = {'text_a': text_a,
+ 'text_b': text_b,
+ 'label': LABELS[label],
+ 'uid': unique_id}
+ total += 1
+ samples.append(sample)
+
+ if total % 50000 == 0:
+ print_rank_0(' > processed {} so far ...'.format(total))
+
+ print_rank_0(' >> processed {} samples.'.format(len(samples)))
+ return samples
diff --git a/megatron_lm/tasks/glue/qqp.py b/megatron_lm/tasks/glue/qqp.py
new file mode 100644
index 0000000..a6adbd0
--- /dev/null
+++ b/megatron_lm/tasks/glue/qqp.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""QQP dataset."""
+
+from megatron import print_rank_0
+from tasks.data_utils import clean_text
+from .data import GLUEAbstractDataset
+
+
+LABELS = [0, 1]
+
+
+class QQPDataset(GLUEAbstractDataset):
+
+ def __init__(self, name, datapaths, tokenizer, max_seq_length,
+ test_label=0):
+ self.test_label = test_label
+ super().__init__('QQP', name, datapaths,
+ tokenizer, max_seq_length)
+
+ def process_samples_from_single_path(self, filename):
+ """"Implement abstract method."""
+ print_rank_0(' > Processing {} ...'.format(filename))
+
+ samples = []
+ total = 0
+ first = True
+ is_test = False
+ with open(filename, 'r') as f:
+ for line in f:
+ row = line.strip().split('\t')
+ if first:
+ first = False
+ if len(row) == 3:
+ is_test = True
+ print_rank_0(' reading {}, {}, and {} columns and '
+ 'setting labels to {}'.format(
+ row[0].strip(), row[1].strip(),
+ row[2].strip(), self.test_label))
+ else:
+ assert len(row) == 6
+ print_rank_0(' reading {}, {}, {}, and {} columns'
+ ' ...'.format(
+ row[0].strip(), row[3].strip(),
+ row[4].strip(), row[5].strip()))
+ continue
+
+ if is_test:
+ assert len(row) == 3, 'expected length 3: {}'.format(row)
+ uid = int(row[0].strip())
+ text_a = clean_text(row[1].strip())
+ text_b = clean_text(row[2].strip())
+ label = self.test_label
+ assert len(text_a) > 0
+ assert len(text_b) > 0
+ else:
+ if len(row) == 6:
+ uid = int(row[0].strip())
+ text_a = clean_text(row[3].strip())
+ text_b = clean_text(row[4].strip())
+ label = int(row[5].strip())
+ else:
+ print_rank_0('***WARNING*** index error, '
+ 'skipping: {}'.format(row))
+ continue
+ if len(text_a) == 0:
+ print_rank_0('***WARNING*** zero length a, '
+ 'skipping: {}'.format(row))
+ continue
+ if len(text_b) == 0:
+ print_rank_0('***WARNING*** zero length b, '
+ 'skipping: {}'.format(row))
+ continue
+ assert label in LABELS
+ assert uid >= 0
+
+ sample = {'uid': uid,
+ 'text_a': text_a,
+ 'text_b': text_b,
+ 'label': label}
+ total += 1
+ samples.append(sample)
+
+ if total % 50000 == 0:
+ print_rank_0(' > processed {} so far ...'.format(total))
+
+ print_rank_0(' >> processed {} samples.'.format(len(samples)))
+ return samples
diff --git a/megatron_lm/tasks/main.py b/megatron_lm/tasks/main.py
new file mode 100644
index 0000000..d8a30d1
--- /dev/null
+++ b/megatron_lm/tasks/main.py
@@ -0,0 +1,69 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main tasks functionality."""
+
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+ os.path.pardir)))
+
+from megatron import get_args
+from megatron.initialize import initialize_megatron
+
+
+def get_tasks_args(parser):
+ """Provide extra arguments required for tasks."""
+ group = parser.add_argument_group(title='tasks')
+
+ group.add_argument('--task', type=str, required=True,
+ help='Task name.')
+ group.add_argument('--epochs', type=int, default=None,
+ help='Number of finetuning epochs. Zero results in '
+ 'evaluation only.')
+ group.add_argument('--pretrained-checkpoint', type=str, default=None,
+ help='Pretrained checkpoint used for finetuning.')
+ group.add_argument('--keep-last', action='store_true',
+ help='Keep the last batch (maybe incomplete) in '
+ 'the data loader.')
+ group.add_argument('--train-data', nargs='+', default=None,
+ help='Whitespace separated paths or corpora names '
+ 'for training.')
+ group.add_argument('--valid-data', nargs='*', default=None,
+ help='path(s) to the validation data.')
+ group.add_argument('--overlapping-eval', type=int, default=32,
+ help='Sliding window for overlapping evaluation.')
+ group.add_argument('--strict-lambada', action='store_true',
+ help='Use more difficult formulation of lambada.')
+
+ return parser
+
+
+if __name__ == '__main__':
+
+ initialize_megatron(extra_args_provider=get_tasks_args)
+
+ args = get_args()
+ if args.task == 'RACE':
+ from race.finetune import main
+ elif args.task in ['MNLI', 'QQP']:
+ from glue.finetune import main
+ elif args.task in ['LAMBADA', 'WIKITEXT103']:
+ from zeroshot_gpt2.evaluate import main
+ else:
+ raise NotImplementedError('Task {} is not implemented.'.format(
+ args.task))
+
+ main()
diff --git a/megatron_lm/tasks/race/data.py b/megatron_lm/tasks/race/data.py
new file mode 100644
index 0000000..f11cad6
--- /dev/null
+++ b/megatron_lm/tasks/race/data.py
@@ -0,0 +1,131 @@
+
+import glob
+import json
+import os
+import time
+
+from torch.utils.data import Dataset
+
+from megatron import print_rank_0
+from tasks.data_utils import build_sample
+from tasks.data_utils import build_tokens_types_paddings_from_ids
+from tasks.data_utils import clean_text
+
+
+NUM_CHOICES = 4
+MAX_QA_LENGTH = 128
+
+
+class RaceDataset(Dataset):
+
+ def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
+ max_qa_length=MAX_QA_LENGTH):
+
+ self.dataset_name = dataset_name
+ print_rank_0(' > building RACE dataset for {}:'.format(
+ self.dataset_name))
+
+ string = ' > paths:'
+ for path in datapaths:
+ string += ' ' + path
+ print_rank_0(string)
+
+ self.samples = []
+ for datapath in datapaths:
+ self.samples.extend(process_single_datapath(datapath, tokenizer,
+ max_qa_length,
+ max_seq_length))
+
+ print_rank_0(' >> total number of samples: {}'.format(
+ len(self.samples)))
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, idx):
+ return self.samples[idx]
+
+
+def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
+ """Read in RACE files, combine, clean-up, tokenize, and convert to
+ samples."""
+
+ print_rank_0(' > working on {}'.format(datapath))
+ start_time = time.time()
+
+ # Get list of files.
+ filenames = glob.glob(os.path.join(datapath, '*.txt'))
+
+ samples = []
+ num_docs = 0
+ num_questions = 0
+ num_samples = 0
+ # Load all the files
+ for filename in filenames:
+ with open(filename, 'r') as f:
+ for line in f:
+ data = json.loads(line)
+ num_docs += 1
+
+ context = data["article"]
+ questions = data["questions"]
+ choices = data["options"]
+ answers = data["answers"]
+ # Check the length.
+ assert len(questions) == len(answers)
+ assert len(questions) == len(choices)
+
+ # Context: clean up and convert to ids.
+ context = clean_text(context)
+ context_ids = tokenizer.tokenize(context)
+
+ # Loop over questions.
+ for qi, question in enumerate(questions):
+ num_questions += 1
+ # Label.
+ label = ord(answers[qi]) - ord("A")
+ assert label >= 0
+ assert label < NUM_CHOICES
+ assert len(choices[qi]) == NUM_CHOICES
+
+ # For each question, build num-choices samples.
+ ids_list = []
+ types_list = []
+ paddings_list = []
+ for ci in range(NUM_CHOICES):
+ choice = choices[qi][ci]
+ # Merge with choice.
+ if "_" in question:
+ qa = question.replace("_", choice)
+ else:
+ qa = " ".join([question, choice])
+ # Clean QA.
+ qa = clean_text(qa)
+ # Tokenize.
+ qa_ids = tokenizer.tokenize(qa)
+ # Trim if needed.
+ if len(qa_ids) > max_qa_length:
+ qa_ids = qa_ids[0:max_qa_length]
+
+ # Build the sample.
+ ids, types, paddings \
+ = build_tokens_types_paddings_from_ids(
+ qa_ids, context_ids, max_seq_length,
+ tokenizer.cls, tokenizer.sep, tokenizer.pad)
+
+ ids_list.append(ids)
+ types_list.append(types)
+ paddings_list.append(paddings)
+
+ # Convert to numpy and add to samples
+ samples.append(build_sample(ids_list, types_list,
+ paddings_list, label,
+ num_samples))
+ num_samples += 1
+
+ elapsed_time = time.time() - start_time
+ print_rank_0(' > processed {} documents, {} questions, and {} samples'
+ ' in {:.2f} seconds'.format(num_docs, num_questions,
+ num_samples, elapsed_time))
+
+ return samples
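The per-choice merge above either fills a cloze blank or appends the candidate answer. A standalone sketch of just that step on a toy RACE-style question:

question = "The author thinks the plan is _ ."
choices = ["good", "bad", "risky", "cheap"]

for choice in choices:
    # Fill the blank for cloze questions, otherwise append the choice.
    if "_" in question:
        qa = question.replace("_", choice)
    else:
        qa = " ".join([question, choice])
    print(qa)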
diff --git a/megatron_lm/tasks/race/finetune.py b/megatron_lm/tasks/race/finetune.py
new file mode 100644
index 0000000..c7bc53e
--- /dev/null
+++ b/megatron_lm/tasks/race/finetune.py
@@ -0,0 +1,63 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Race."""
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_tokenizer
+from megatron.model.multiple_choice import MultipleChoice
+from tasks.eval_utils import accuracy_func_provider
+from tasks.finetune_utils import finetune
+from tasks.race.data import RaceDataset
+
+
+def train_valid_datasets_provider():
+ """Provide train and validation datasets."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ train_dataset = RaceDataset('training', args.train_data,
+ tokenizer, args.seq_length)
+ valid_dataset = RaceDataset('validation', args.valid_data,
+ tokenizer, args.seq_length)
+
+ return train_dataset, valid_dataset
+
+
+def model_provider():
+ """Build the model."""
+
+ print_rank_0('building multichoice model for RACE ...')
+
+ return MultipleChoice(num_tokentypes=2)
+
+
+def metrics_func_provider():
+ """Privde metrics callback function."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ def single_dataset_provider(datapath):
+ name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
+ return RaceDataset(name, [datapath], tokenizer, args.seq_length)
+
+ return accuracy_func_provider(single_dataset_provider)
+
+
+def main():
+
+ finetune(train_valid_datasets_provider, model_provider,
+ end_of_epoch_callback_provider=metrics_func_provider)
diff --git a/megatron_lm/tasks/zeroshot_gpt2/datasets.py b/megatron_lm/tasks/zeroshot_gpt2/datasets.py
new file mode 100644
index 0000000..0d1f037
--- /dev/null
+++ b/megatron_lm/tasks/zeroshot_gpt2/datasets.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Zero-shot datasets."""
+
+import json
+import math
+
+import numpy as np
+import torch
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_tokenizer
+from .detokenizer import get_detokenizer
+
+
+def build_dataset(task):
+ """Helper function to select and build dataset."""
+
+ if task == 'LAMBADA':
+ return _build_lambada_dataset()
+ if task == 'WIKITEXT103':
+ return _build_wikitext103_dataset()
+
+ raise NotImplementedError('dataset for {} task is not '
+ 'implemented.'.format(task))
+
+
+class _LMDataset(torch.utils.data.Dataset):
+
+ def __init__(self, tokens, seq_len, pad_idx, num_original_tokens,
+ num_tokenized_tokens, overalapping_eval=None):
+ self.tokens = tokens
+ self.seq_len = seq_len
+ self.pad_idx = pad_idx
+ self.overalapping_eval = overalapping_eval
+ if self.overalapping_eval is None:
+ self.overalapping_eval = self.seq_len
+ self.overalapping_eval = max(1, self.overalapping_eval)
+ self.num_original_tokens = num_original_tokens
+ self.num_tokenized_tokens = num_tokenized_tokens
+ self.total_targets = len(self.tokens) - 1
+ # remove first sequence tokens
+ targets = max(self.total_targets - self.overalapping_eval, 0)
+ self.total_sequences = max(
+ math.ceil(targets / self.overalapping_eval) + 1, 1)
+
+ def __len__(self):
+ return self.total_sequences
+
+ def __getitem__(self, idx):
+ start_idx = idx * self.overalapping_eval
+ end_idx = start_idx + self.seq_len
+ tokens = self.tokens[start_idx:end_idx + 1]
+ num_tokens = len(tokens)
+ pad_mask = [1] * num_tokens
+ if num_tokens < self.seq_len + 1:
+ num_pad = (self.seq_len + 1 - num_tokens)
+ pad_mask += [0] * (num_pad)
+ tokens += [self.pad_idx] * num_pad
+ pad_mask = np.array(pad_mask[1:])
+ if self.overalapping_eval != self.seq_len and idx != 0:
+ pad_mask[:-self.overalapping_eval] *= 0
+
+ return {'text': np.array(tokens), 'pad_mask': pad_mask}
+
+
+class _LambadaDataset(torch.utils.data.Dataset):
+
+ def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
+ print_rank_0('> building lambada dataset from {} ...'.format(path))
+ self.seq_len = seq_len
+ self.pad_idx = pad_idx
+ self.tokenizer = tokenizer
+ self.strict = strict
+
+ self.tokens = []
+ self.labels = []
+ with open(path, 'r') as f:
+ for line in f.readlines():
+ text = json.loads(line)['text']
+ tokens, labels = self.get_tokens(text)
+ self.tokens.append(tokens)
+ self.labels.append(labels)
+
+ def get_tokens(self, text):
+ if not self.strict:
+ tokens = self.tokenizer.tokenize(text)
+ return tokens[:-1], [tokens[-1]]
+ last_token = text.split()[-1]
+ start_idx = text.rfind(last_token)
+ beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
+ last_token = self.tokenizer.tokenize(' ' + last_token)
+ return beginning_tokens, last_token
+
+ def __len__(self):
+ return len(self.tokens)
+
+ def __getitem__(self, idx):
+ tokens = self.tokens[idx]
+ num_tokens = len(tokens)
+ pad_mask = [0] * num_tokens
+ labels = self.labels[idx]
+ pad_mask += [1] * len(labels)
+ tokens = tokens + labels
+ num_tokens = len(tokens)
+ if num_tokens < self.seq_len + 1:
+ num_pad = (self.seq_len + 1 - num_tokens)
+ pad_mask += [0] * (num_pad)
+ tokens += [self.pad_idx] * num_pad
+ pad_mask = np.array(pad_mask[1:])
+
+ return {'text': np.array(tokens), 'pad_mask': pad_mask}
+
+
+def _build_lambada_dataset():
+ """Build lambada dataset."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ assert len(args.valid_data) == 1
+ val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
+ args.seq_length, args.strict_lambada)
+ print_rank_0(' > found {} samples.'.format(len(val_dataset)))
+
+ return val_dataset
+
+
+def _build_wikitext103_dataset():
+ """"""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ assert len(args.valid_data) == 1
+ with open(args.valid_data[0], "rb") as reader:
+ entire_data = reader.read().decode('utf-8')
+ num_original_tokens = len(entire_data.strip().split(" "))
+ entire_data = get_detokenizer(args.valid_data[0])(entire_data)
+ tokenized_data = tokenizer.tokenize(entire_data)
+ num_tokenized_tokens = len(tokenized_data)
+
+ val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
+ num_original_tokens, num_tokenized_tokens,
+ args.overlapping_eval)
+ print_rank_0(' > number of original tokens: {}, number of detokenized '
+ 'tokens: {}'.format(num_original_tokens, num_tokenized_tokens))
+
+ return val_dataset
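A standalone numeric sketch of the overlapping-window indexing used by _LMDataset above, on a toy token list (seq_len and overlap are arbitrary stand-ins for args.seq_length and args.overlapping_eval; the real class additionally builds a pad mask that keeps only the last overlap targets of each window after the first):

import math

tokens = list(range(23))      # toy token ids
seq_len, overlap = 8, 4

total_targets = len(tokens) - 1
num_windows = max(math.ceil(max(total_targets - overlap, 0) / overlap) + 1, 1)
for idx in range(num_windows):
    start = idx * overlap
    print(idx, tokens[start:start + seq_len + 1])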
diff --git a/megatron_lm/tasks/zeroshot_gpt2/detokenizer.py b/megatron_lm/tasks/zeroshot_gpt2/detokenizer.py
new file mode 100644
index 0000000..2bc8728
--- /dev/null
+++ b/megatron_lm/tasks/zeroshot_gpt2/detokenizer.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Detokenization."""
+
+import re
+
+
+def ptb_detokenizer(string):
+ string = string.replace(" '", "'")
+ string = string.replace(" \n", "\n")
+ string = string.replace("\n ", "\n")
+ string = string.replace(" n't", "n't")
+ string = string.replace(" N ", "1 ")
+ string = string.replace("$ 1", "$1")
+ string = string.replace("# 1", "#1")
+ return string
+
+
+def wikitext_detokenizer(string):
+ # contractions
+ string = string.replace("s '", "s'")
+ string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+ # number separators
+ string = string.replace(" @-@ ", "-")
+ string = string.replace(" @,@ ", ",")
+ string = string.replace(" @.@ ", ".")
+ # punctuation
+ string = string.replace(" : ", ": ")
+ string = string.replace(" ; ", "; ")
+ string = string.replace(" . ", ". ")
+ string = string.replace(" ! ", "! ")
+ string = string.replace(" ? ", "? ")
+ string = string.replace(" , ", ", ")
+ # double brackets
+ string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+ string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+ string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+ string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+ string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+ # miscellaneous
+ string = string.replace("= = = =", "====")
+ string = string.replace("= = =", "===")
+ string = string.replace("= =", "==")
+ string = string.replace(" " + chr(176) + " ", chr(176))
+ string = string.replace(" \n", "\n")
+ string = string.replace("\n ", "\n")
+ string = string.replace(" N ", " 1 ")
+ string = string.replace(" 's", "'s")
+
+ return string
+
+
+def lambada_detokenizer(string):
+ return string
+
+
+_DETOKENIZERS = {
+ 'ptb': ptb_detokenizer,
+ 'wiki': wikitext_detokenizer,
+ 'lambada': lambada_detokenizer,
+}
+
+
+def get_detokenizer(path):
+ for key in _DETOKENIZERS.keys():
+ if key in path:
+ return _DETOKENIZERS[key]
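A quick usage sketch for these detokenizers, assuming the megatron_lm checkout has been added to sys.path so the tasks package is importable; note that get_detokenizer falls through and returns None when none of the keys ('ptb', 'wiki', 'lambada') occur in the path:

import sys
sys.path.append('megatron_lm')   # assumed checkout location

from tasks.zeroshot_gpt2.detokenizer import get_detokenizer, wikitext_detokenizer

raw = "It sold 1 @,@ 000 copies . The band 's next album followed ."
print(wikitext_detokenizer(raw))
# -> "It sold 1,000 copies. The band's next album followed ."
print(get_detokenizer('/data/wikitext-103/wiki.valid.tokens').__name__)
# -> "wikitext_detokenizer"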
diff --git a/megatron_lm/tasks/zeroshot_gpt2/evaluate.py b/megatron_lm/tasks/zeroshot_gpt2/evaluate.py
new file mode 100644
index 0000000..b1c06d2
--- /dev/null
+++ b/megatron_lm/tasks/zeroshot_gpt2/evaluate.py
@@ -0,0 +1,195 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPT2 zero-shot evaluation."""
+
+import math
+
+import torch
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron.checkpointing import load_checkpoint
+from megatron.model import GPT2Model
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids
+from tasks.finetune_utils import build_data_loader
+
+from .datasets import build_dataset
+
+
+def get_model_provider(eval_metric):
+ """Based on evaluation metric set the parallel-output flag and
+ return the model provider."""
+
+ def model_provider():
+ """Build the model."""
+
+ if eval_metric == 'loss':
+ parallel_output = True
+ elif eval_metric == 'accuracy':
+ parallel_output = False
+ else:
+ raise NotImplementedError('output type for {} evaluation metric '
+ 'is not supported.'.format(eval_metric))
+
+ print_rank_0('building GPT2 model ...')
+ model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)
+
+ return model
+
+ return model_provider
+
+
+def process_batch(batch):
+ """Process batch and produce inputs for the model."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
+ tokens_ = batch['text'].long().cuda().contiguous()
+ labels = tokens_[:, 1:].contiguous()
+ tokens = tokens_[:, :-1].contiguous()
+
+ # Get the masks and postition ids.
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+ tokens,
+ tokenizer.eod,
+ args.reset_position_ids,
+ args.reset_attention_mask,
+ args.eod_mask_loss)
+
+ return tokens, labels, attention_mask, position_ids, loss_mask
+
+
+def forward_step(batch, model, eval_metric):
+ """Forward step."""
+
+ # Get the batch.
+ tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
+ batch)
+
+ # Forward model.
+ output = model(tokens, position_ids, attention_mask)
+
+ # For loss, return the unreduced loss.
+ if eval_metric == 'loss':
+ losses = mpu.vocab_parallel_cross_entropy(
+ output.contiguous().float(), labels.contiguous())
+ loss = torch.sum(
+ losses.view(-1) * loss_mask.contiguous().view(-1).float())
+ return loss
+
+ # For accuracy, return the number of correctly predicted samples.
+ if eval_metric == 'accuracy':
+ outputs = torch.argmax(output, -1)
+ correct = (outputs == labels).float()
+ correct[(1 - loss_mask).bool()] = 1
+ correct = correct.prod(-1)
+ return correct.sum()
+
+ raise NotImplementedError('forward method for evaluation metric {} '
+ 'is not implemented.'.format(eval_metric))
+
+
+def evaluate(data_loader, model, eval_metric):
+ """Evaluation."""
+ args = get_args()
+
+ # Turn on evaluation mode which disables dropout.
+ model.eval()
+
+ total_output = 0.0
+ with torch.no_grad():
+ # For all the batches in the dataset.
+ for iteration, batch in enumerate(data_loader):
+ if iteration % args.log_interval == 0:
+ print_rank_0('> working on iteration: {}'.format(iteration))
+ # Forward evaluation.
+ output = forward_step(batch, model, eval_metric)
+
+ # Reduce across processes.
+ torch.distributed.all_reduce(output,
+ group=mpu.get_data_parallel_group())
+
+ total_output += output
+
+ return total_output
+
+
+def evaluate_and_print_results(task, data_loader, model, eval_metric):
+ """Evaluate and print results on screen."""
+
+ # Evaluate and get results.
+ output = evaluate(data_loader, model, eval_metric)
+
+ string = ' validation results on {} | '.format(task)
+ if eval_metric == 'loss':
+ num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
+ num_original_tokens = data_loader.dataset.num_original_tokens
+ val_loss = output / (num_tokenized_tokens - 1)
+ ppl = math.exp(min(20, val_loss))
+ token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
+ adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
+ string += 'avg loss: {:.4E} | '.format(val_loss)
+ string += 'ppl: {:.4E} | '.format(ppl)
+ string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
+ string += 'token ratio: {} |'.format(token_ratio)
+
+ elif eval_metric == 'accuracy':
+ num_examples = len(data_loader.dataset)
+ acc = output / num_examples
+ string += 'number correct: {:.4E} | '.format(output)
+ string += 'total examples: {:.4E} | '.format(num_examples)
+ string += 'avg accuracy: {:.4E}'.format(acc)
+
+ else:
+ raise NotImplementedError('evaluation method for {} metric is not '
+ 'implemented yet.'.format(eval_metric))
+
+ length = len(string) + 1
+ print_rank_0('-' * length)
+ print_rank_0(string)
+ print_rank_0('-' * length)
+
+
+def main():
+ """Main program."""
+ args = get_args()
+
+ if args.task == 'LAMBADA':
+ eval_metric = 'accuracy'
+ elif args.task == 'WIKITEXT103':
+ eval_metric = 'loss'
+ else:
+ raise NotImplementedError('{} task is not implemented.'.format(
+ args.task))
+
+ # Set up model and load checkpoint.
+ model = get_model(get_model_provider(eval_metric))
+ if args.load is not None:
+ _ = load_checkpoint(model, None, None)
+
+ # Data stuff.
+ dataset = build_dataset(args.task)
+ dataloader = build_data_loader(dataset, args.batch_size,
+ args.num_workers, drop_last=False)
+
+ # Run evaluation.
+ evaluate_and_print_results(args.task, dataloader, model, eval_metric)
+
+ print_rank_0('done :-)')
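The loss branch of evaluate_and_print_results above converts a summed token loss into perplexity, plus an 'adjusted' perplexity that compensates for the detokenizer and BPE changing the token count. A standalone numeric sketch with made-up totals:

import math

total_loss = 21000.0           # hypothetical sum of per-token losses
num_tokenized_tokens = 6500    # hypothetical count after BPE tokenization
num_original_tokens = 5000     # hypothetical whitespace-token count of the raw file

val_loss = total_loss / (num_tokenized_tokens - 1)
ppl = math.exp(min(20, val_loss))
token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
print('avg loss: {:.4f} | ppl: {:.2f} | adjusted ppl: {:.2f}'.format(
    val_loss, ppl, adjusted_ppl))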
diff --git a/megatron_lm/tools/create_doc_index.py b/megatron_lm/tools/create_doc_index.py
new file mode 100644
index 0000000..1e14d1d
--- /dev/null
+++ b/megatron_lm/tools/create_doc_index.py
@@ -0,0 +1,30 @@
+import sys
+sys.path.append('../')
+
+from megatron.indexer import IndexBuilder
+from megatron.initialize import initialize_megatron
+
+
+def main():
+ """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
+ - Include all args needed for initial model specification
+
+ Other key args:
+ --block-data-path: path to write to
+ --ict-load or --realm-load: path to checkpoint with which to embed
+ --data-path and --titles-data-path: paths for dataset
+ --indexer-log-interval: reporting interval
+ --indexer-batch-size: size specific for indexer jobs
+
+ Check README.md for example script
+ """
+
+ initialize_megatron(extra_args_provider=None,
+ args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+ index_builder = IndexBuilder()
+ index_builder.build_and_save_index()
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/megatron_lm/tools/generate_samples_gpt2.py b/megatron_lm/tools/generate_samples_gpt2.py
new file mode 100644
index 0000000..93e1e5e
--- /dev/null
+++ b/megatron_lm/tools/generate_samples_gpt2.py
@@ -0,0 +1,104 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sample Generate GPT2"""
+
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+ os.path.pardir)))
+
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_tokenizer
+from megatron.checkpointing import load_checkpoint
+from megatron.initialize import initialize_megatron
+from megatron.model import GPT2Model
+from megatron.training import get_model
+from megatron.text_generation_utils import generate_and_write_samples_unconditional
+from megatron.text_generation_utils import generate_samples_input_from_file
+from megatron.text_generation_utils import generate_samples_interactive
+
+
+def model_provider():
+ """Build the model."""
+
+ print_rank_0('building GPT2 model ...')
+ model = GPT2Model(num_tokentypes=0, parallel_output=False)
+
+ return model
+
+
+def add_text_generate_args(parser):
+ """Text generation arguments."""
+ group = parser.add_argument_group(title='text generation')
+
+ group.add_argument("--temperature", type=float, default=1.0,
+ help='Sampling temperature.')
+ group.add_argument("--greedy", action='store_true', default=False,
+ help='Use greedy sampling.')
+ group.add_argument("--top_p", type=float, default=0.0,
+ help='Top p sampling.')
+ group.add_argument("--top_k", type=int, default=0,
+ help='Top k sampling.')
+ group.add_argument("--out-seq-length", type=int, default=1024,
+ help='Size of the output generated text.')
+ group.add_argument("--sample-input-file", type=str, default=None,
+ help='Get input from file instead of interactive mode, '
+ 'each line is json.')
+ group.add_argument("--sample-output-file", type=str, default=None,
+ help='Output file got from --sample-input-file')
+ group.add_argument("--sample-context-field", type=str, default="prefix",
+ help='Field name of context in --sample-input-file lines.')
+ group.add_argument("--sample-generated-field", type=str, default="suffix",
+ help='Field name of generated text written to --sample-output-file lines.')
+ group.add_argument("--num-samples", type=int, default=0,
+ help='Number of samples to generate unconditionally, '
+ 'defaults to 0 and interactive conditional sampling')
+ group.add_argument("--genfile", type=str,
+ help='Output file when generating unconditionally')
+ group.add_argument("--recompute", action='store_true',
+ help='During generation recompute all attention '
+ 'instead of using previously computed keys/values.')
+
+ return parser
+
+
+def main():
+ """Main program."""
+
+ initialize_megatron(extra_args_provider=add_text_generate_args,
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+
+ # Set up model and load checkpoint.
+ model = get_model(model_provider)
+ args = get_args()
+ if args.load is not None:
+ _ = load_checkpoint(model, None, None)
+
+ # Generate samples.
+ if args.num_samples == 0:
+ args.batch_size = 1
+ if args.sample_input_file is not None:
+ generate_samples_input_from_file(model)
+ else:
+ generate_samples_interactive(model)
+ else:
+ generate_and_write_samples_unconditional(model)
+
+
+if __name__ == "__main__":
+
+ main()
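The --sample-input-file path above expects one JSON object per line, with the prompt stored under the field named by --sample-context-field (default 'prefix') and the generated text written back under --sample-generated-field (default 'suffix'). A minimal sketch that writes such an input file (the path and prompts are made up):

import json

prompts = ["Once upon a time", "The meaning of life is"]
with open('cond_input.json', 'w') as f:
    for text in prompts:
        f.write(json.dumps({'prefix': text}, ensure_ascii=False) + '\n')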
diff --git a/megatron_lm/tools/linter.py b/megatron_lm/tools/linter.py
new file mode 100644
index 0000000..5b14007
--- /dev/null
+++ b/megatron_lm/tools/linter.py
@@ -0,0 +1,36 @@
+import os
+import os.path as osp
+import pathlib
+import subprocess
+
+
+def recursively_lint_files():
+ """Recursively lint all python files in chosen subdirectories of megatron-lm"""
+
+ try:
+ import autopep8
+ except ModuleNotFoundError:
+ print("Please first install autopep8 via `pip install autopep8`")
+ return
+
+ # get all python file paths from top level directory
+ file_dir = str(pathlib.Path(__file__).parent.absolute())
+ working_dir = osp.join(file_dir, os.pardir)
+ all_py_paths = set(os.path.join(working_dir, fname)
+ for fname in os.listdir(working_dir) if ".py" in fname)
+
+ # get all python file paths from chosen subdirectories
+ check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks']
+ for sub_dir in check_dirs:
+ for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)):
+ all_py_paths.update(set(osp.join(path, fname) for fname in fnames if ".py" in fname))
+
+ print("Linting the following: ")
+ for py_path in all_py_paths:
+ print(py_path)
+ command = ['autopep8', '--max-line-length', '100', '--aggressive', '--in-place', py_path]
+ subprocess.check_call(command)
+
+
+if __name__ == "__main__":
+ recursively_lint_files()
diff --git a/megatron_lm/tools/merge_mp_partitions.py b/megatron_lm/tools/merge_mp_partitions.py
new file mode 100644
index 0000000..4639adc
--- /dev/null
+++ b/megatron_lm/tools/merge_mp_partitions.py
@@ -0,0 +1,290 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Merge model parallel partitions."""
+
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+ os.path.pardir)))
+
+import torch
+
+from megatron import mpu
+from megatron.checkpointing import ensure_directory_exists
+from megatron.checkpointing import get_checkpoint_name
+from megatron.checkpointing import get_checkpoint_tracker_filename
+from megatron.global_vars import rebuild_tokenizer
+from megatron.global_vars import _parse_args
+
+
+def split_into_partitions(tensor, num_partitions, partition_dim, stride):
+
+ per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
+ num_partitions)
+ per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
+
+ partitions_list = torch.split(tensor,
+ per_partition_per_stride_size,
+ dim=partition_dim)
+
+ partitions = []
+ for i in range(num_partitions):
+ partition = torch.cat(partitions_list[i::num_partitions],
+ dim=partition_dim)
+ partitions.append(partition)
+
+ return partitions
+
+
+def merge_partitions(merged, partitions, partition_dim, stride):
+
+ # Number and size of each partition.
+ num_partitions = len(partitions)
+ per_partition_size = None
+ for partition in partitions:
+ if per_partition_size is None:
+ per_partition_size = partition.size(partition_dim)
+ else:
+ assert per_partition_size == partition.size(partition_dim)
+
+ def concat_partitions(partitions_):
+ with torch.no_grad():
+ if (per_partition_size * num_partitions) == merged.size(
+ partition_dim):
+ torch.cat(partitions_, dim=partition_dim, out=merged)
+ else:
+ print(' ***WARNING*** sizes do not match. Will cut '
+ 'the merged partitions by {} along dimension {} '
+ 'to reduce the size from {} to {} ...'.format(
+ (per_partition_size * num_partitions) - \
+ merged.size(partition_dim), partition_dim,
+ per_partition_size * num_partitions,
+ merged.size(partition_dim)))
+ merged_ = torch.cat(partitions_, dim=partition_dim)
+ merged_split = torch.split(merged_, merged.size(partition_dim),
+ dim=partition_dim)
+ merged_ = merged_split[0]
+ assert merged_.size(partition_dim) == merged.size(partition_dim)
+ merged.data.copy_(merged_.data)
+
+ # If stride is 1, then do simple concatination.
+ if stride == 1:
+ concat_partitions(partitions)
+ return
+
+ # For non-unity strides, first split based on stride and then group.
+ per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
+ # Chunk and build a list.
+ chunks = None
+ for i, partition in enumerate(partitions):
+ chunk = torch.split(partition,
+ per_partition_per_stride_size,
+ dim=partition_dim)
+
+ if chunks is None:
+ chunks = [0]*(num_partitions*len(chunk))
+ chunks[i::num_partitions] = chunk
+
+ # Concatenate.
+ concat_partitions(chunks)
+
+ return
+
+
+def get_model(model_type):
+
+ if model_type == 'BERT':
+ from pretrain_bert import model_provider
+ elif model_type == 'GPT2':
+ from pretrain_gpt2 import model_provider
+ elif model_type == 'RACE':
+ from tasks.race.finetune import model_provider
+ elif model_type in ['MNLI', 'QQP']:
+ num_classes = 2
+ if model_type == 'MNLI':
+ num_classes = 3
+ from megatron.model.classification import Classification
+ def model_provider():
+ return Classification(num_classes=num_classes, num_tokentypes=2)
+ else:
+ raise Exception('unrecognized model type: {}'.format(model_type))
+
+ model = model_provider(ds_init=False)
+ model = model.half()
+
+ return model
+
+
+def get_parallel_checkpoint_name(path):
+
+ tracker_filename = get_checkpoint_tracker_filename(path)
+ iteration = 0
+ with open(tracker_filename, 'r') as f:
+ metastring = f.read().strip()
+ iteration = int(metastring)
+ # assert iteration > 0
+ checkpoint_name = get_checkpoint_name(path, iteration)
+
+ return checkpoint_name, iteration
+
+
+def test_split_merge():
+
+ print('testing split and merge ...')
+
+ #[QKV.ROW-COL]
+ tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
+ [1.21, 1.22, 1.23, 1.24, 1.25],
+ [1.31, 1.32, 1.33, 1.34, 1.35],
+ [1.41, 1.42, 1.43, 1.44, 1.45],
+ [2.11, 2.12, 2.13, 2.14, 2.15],
+ [2.21, 2.22, 2.23, 2.24, 2.25],
+ [2.31, 2.32, 2.33, 2.34, 2.35],
+ [2.41, 2.42, 2.43, 2.44, 2.45],
+ [3.11, 3.12, 3.13, 3.14, 3.15],
+ [3.21, 3.22, 3.23, 3.24, 3.25],
+ [3.31, 3.32, 3.33, 3.34, 3.35],
+ [3.41, 3.42, 3.43, 3.44, 3.45]])
+
+ num_partitions = 2
+ partition_dim = 0
+ stride = 3
+ partitions = split_into_partitions(tensor, num_partitions,
+ partition_dim, stride)
+
+ merged = torch.zeros_like(tensor)
+ merge_partitions(merged, partitions, partition_dim, stride)
+
+ max_error = (merged - tensor).abs().max()
+ print(' > max error (should be zero): {}'.format(max_error))
+
+
+def get_mp_merge_args(parser):
+ """Provide extra arguments required for merging."""
+ group = parser.add_argument_group(title='mp merge')
+
+ group.add_argument('--model-type', type=str, required=True,
+ choices=['BERT', 'GPT2', 'RACE', 'MNLI', 'QQP'],
+ help='Type of the model.')
+
+ return parser
+
+
+def main():
+
+ # Args
+ from megatron.initialize import initialize_megatron
+ from megatron import get_args
+ initialize_megatron(extra_args_provider=get_mp_merge_args)
+ args = get_args()
+ # args = _parse_args(extra_args_provider=get_mp_merge_args)
+ model_type = args.model_type
+ orig_model_parallel_size = args.model_parallel_size
+ args.model_parallel_size = 1
+ tokenizer = rebuild_tokenizer(args)
+
+ print('\n merging model parallel partitions ...')
+ print(' > number of partitions: {}'.format(orig_model_parallel_size))
+ print(' > checkpoint path: {}'.format(args.load))
+ print(' > model parameters:')
+ print(' number of tokens ................ {} '.format(
+ tokenizer.vocab_size))
+ print(' number of layers ................ {}'.format(args.num_layers))
+ print(' hidden size ..................... {}'.format(args.hidden_size))
+ print(' number of attention heads ....... {}'.format(
+ args.num_attention_heads))
+ print(' maximum position embeddings ..... {}'.format(
+ args.max_position_embeddings))
+
+ # Full model.
+ print('> building the full model ...')
+ mpu.initialize.set_model_parallel_world_size(1)
+ mpu.initialize.set_model_parallel_rank(0)
+ merged_model = get_model(model_type)
+
+ # Build and load partitions.
+ partitions = []
+ iteration = 0
+ args.model_parallel_size = orig_model_parallel_size
+ tokenizer = rebuild_tokenizer(args)
+ mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
+ for rank in range(args.model_parallel_size):
+ mpu.initialize.set_model_parallel_rank(rank)
+ checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
+ print('> loading {} ...'.format(checkpoint_name))
+ model_ = get_model(model_type)
+ sd = torch.load(checkpoint_name, map_location='cpu')
+ model_.load_state_dict(sd['model'])
+ partitions.append(model_)
+
+
+ # Parameter generators so we can loop through them simultaneously.
+ merged_params_gen = merged_model.named_parameters()
+ partitions_params_gen = [partition.named_parameters()
+ for partition in partitions]
+ while True:
+ try:
+
+ # Get the params and check names.
+ name, merged_param = next(merged_params_gen)
+ print(' > working on {} ...'.format(name))
+ print(' merged type: {}, size: {}'.format(
+ merged_param.dtype, list(merged_param.size())))
+ partitions_param = []
+ for rank, partition_params_gen in enumerate(partitions_params_gen):
+ partition_name, partition_param = next(partition_params_gen)
+ assert partition_name == name
+ partitions_param.append(partition_param)
+ print(' partition {} type: {}, size: {}'.format(
+ rank, partition_param.dtype, list(partition_param.size())))
+
+ # For the non-parallel parameters, simply copy the rank 0 values.
+ if not hasattr(merged_param, 'model_parallel'):
+ print(' non-parallel parameter, simple copy from rank 0')
+ with torch.no_grad():
+ merged_param.data.copy_(partitions_param[0].data)
+ # For parallel parameters, merge the values
+ else:
+ print(' parallel parameter merge with stride {} along '
+ 'dimension {}'.format(merged_param.partition_stride,
+ merged_param.partition_dim))
+ merge_partitions(merged_param,
+ partitions_param,
+ merged_param.partition_dim,
+ merged_param.partition_stride)
+
+ except StopIteration:
+ break
+
+
+ # Save the model.
+ args.model_parallel_size = 1
+ mpu.initialize.set_model_parallel_rank(0)
+ sd = {}
+ sd['model'] = merged_model.state_dict_for_save_checkpoint()
+ sd['iteration'] = iteration
+ merged_path = os.path.join(args.load, 'merged')
+ checkpoint_name = get_checkpoint_name(merged_path, iteration)
+ ensure_directory_exists(checkpoint_name)
+ print('> saving merged model to {}'.format(checkpoint_name))
+ torch.save(sd, checkpoint_name)
+
+ print('done :-)')
+
+
+if __name__ == '__main__':
+
+ main()
diff --git a/megatron_lm/tools/openwebtext/README.md b/megatron_lm/tools/openwebtext/README.md
new file mode 100644
index 0000000..db55e10
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/README.md
@@ -0,0 +1,46 @@
+The following steps show how to prepare the training dataset used to train the model.
+
+# Libraries to install
+
+```
+ pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
+ git clone https://github.com/mattilyra/LSH
+ cd LSH
+ python setup.py install
+```
+
+# Download the dataset
+
+1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
+2. Remove blacklisted URLs.
+```
+python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls, e.g. clean_urls.txt>
+```
+3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
+
+4. Merge the contents into one loose JSON file with one JSON object per line, of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique; an example line is shown below.
+
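+A minimal sketch of one such line (the url here is just a placeholder):
+```
+{"text": "Document text goes here ...", "url": "https://example.com/some-unique-page"}
+```
+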
+# Prepare the data for GPT-2 training:
+
+1. Run ftfy, perform English-language detection, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards (a sharding sketch follows the command below).
+```
+python cleanup_dataset.py <input data file> <output cleaned data filename>
+```
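+If you do shard, a minimal sketch using GNU `split` (assuming the merged file from the previous section is named `merged.json`; each shard is then cleaned independently):
+```
+split -l 1000000 merged.json shard_
+for f in shard_*; do python cleanup_dataset.py $f $f.cleaned; done
+```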
+2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset.
+```
+python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
+```
+3. Based on the similarity measure defined inside the `is_similar` function (default threshold: 0.9), group urls that are similar. For each group, only one url is kept and the rest are removed.
+```
+python group_duplicates_url.py <possible duplicate urls file> <output file containing similar urls>
+```
+4. Remove similar documents that were detected in the last step.
+```
+python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <output file containing deduplicated data>
+```
+
+5. Shuffle the dataset.
+```
+shuf <cleaned deduped data file> -o train_data.json
+```
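+
+Putting the steps above together, a condensed sketch of the full preparation pipeline (all filenames are placeholders):
+```
+python blacklist_urls.py urls/ clean_urls.txt
+# ... download the content for clean_urls.txt and merge it into merged.json ...
+python cleanup_dataset.py merged.json cleaned.json
+python find_duplicates.py cleaned.json possible_duplicates.json
+python group_duplicates_url.py possible_duplicates.json similar_urls.json
+python remove_group_duplicates.py similar_urls.json cleaned.json deduped.json
+shuf deduped.json -o train_data.json
+```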
+
diff --git a/megatron_lm/tools/openwebtext/blacklist_urls.py b/megatron_lm/tools/openwebtext/blacklist_urls.py
new file mode 100644
index 0000000..d9c6f20
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/blacklist_urls.py
@@ -0,0 +1,312 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import glob
+import re
+import time
+import tldextract
+import sys
+
+
+# List of the domains to blacklist.
+domain_blacklist = set([
+ '500px',
+ 'aapks',
+ 'akamaihd',
+ 'amazon',
+ 'apple',
+ 'artifactfire',
+ 'artstation',
+ 'awwni',
+ 'bandcamp',
+ 'battleforthenet',
+ 'coinscalendar',
+ 'dailymotion',
+ 'deviantart',
+ 'discord',
+ 'discordapp',
+ 'dlapkandroid',
+ 'dropbox',
+ 'e621',
+ 'ebay',
+ 'edealinfo',
+ 'erome',
+ 'eroshare',
+ 'explosm',
+ 'facebook',
+ 'fbcdn',
+ 'flickr',
+ 'furaffinity',
+ 'futhead',
+ 'gatopardo',
+ 'gfycat',
+ 'gifsound',
+ 'gifsoup',
+ 'giphy',
+ 'github',
+ 'google',
+ 'gunprime',
+ 'gyazo',
+ 'hotdealstar',
+ 'imagefap',
+ 'imageshack',
+ 'imgflip',
+ 'imgur',
+ 'instagram',
+ 'karmadecay',
+ 'kryptocal',
+ 'kym-cdn',
+ 'liveleak',
+ 'livememe',
+ 'lmgtfy',
+ 'magaimg',
+ 'memegenerator',
+ 'minorplanetcenter',
+ 'minus',
+ 'mobafire',
+ 'morejpeg',
+ 'nocookie',
+ 'pcpartpicker',
+ 'photobucket',
+ 'pinimg',
+ 'pinterest',
+ 'pixiv',
+ 'pornhub',
+ 'prntscr',
+ 'puu',
+ 'qkme',
+ 'quickmeme',
+ 'radd',
+ 'redd',
+ 'reddit',
+ 'reddit-stream',
+ 'redditlog',
+ 'redditmedia',
+ 'reddituploads',
+ 'redtube',
+ 'reupp',
+ 'reverb',
+ 'roanoke',
+ 'rollingstone',
+ 'sli',
+ 'soundcloud',
+ 'soundgasm',
+ 'spankbang',
+ 'spotify',
+ 'strawpoll',
+ 'streamable',
+ 'timeanddate',
+ 'tinypic',
+ 'touhouradio',
+ 'tumblr',
+ 'twimg',
+ 'twitch',
+ 'twitter',
+ 'vid',
+ 'vimeo',
+ 'vine',
+ 'vkaao',
+ 'vocaroo',
+ 'voyagefusion',
+ 'walmart',
+ 'wciu',
+ 'wikimedia',
+ 'wikipedia',
+ 'xhamster',
+ 'xkcd',
+ 'xvideos',
+ 'youtu',
+ 'youtube',
+ 'youtubedoubler',
+ 'ytimg',
+ 'zillexplorer',
+])
+
+def domain_is_in_blacklist(url):
+ domain = tldextract.extract(url).domain
+ return domain in domain_blacklist
+
+
+# List of extensions to blacklist.
+extentions_blacklist = (
+ '.3gp',
+ '.7z',
+ '.ai',
+ '.aif',
+ '.apk',
+ '.app',
+ '.avi',
+ '.bin',
+ '.bmp',
+ '.bz2',
+ '.css',
+ '.csv',
+ '.dat',
+ '.deb',
+ '.dmg',
+ '.doc',
+ '.docx',
+ '.exe',
+ '.gif',
+ '.gifv',
+ '.gz',
+ '.iso',
+ '.jar',
+ '.jpeg',
+ '.jpg',
+ '.js',
+ '.log',
+ '.mid',
+ '.midi',
+ '.mkv',
+ '.mov',
+ '.mp3',
+ '.mp4',
+ '.mpeg',
+ '.mpg',
+ '.ogg',
+ '.ogv',
+ '.otf',
+ '.pdf',
+ '.pkg',
+ '.png',
+ '.pps',
+ '.ppt',
+ '.pptx',
+ '.psd',
+ '.py',
+ '.qt',
+ '.ram',
+ '.rar',
+ '.sql',
+ '.svg',
+ '.swf',
+ '.tar.gz',
+ '.tar',
+ '.tgz',
+ '.tiff',
+ '.ttf',
+ '.txt',
+ '.wav',
+ '.webm',
+ '.wma',
+ '.wmv',
+ '.xls',
+ '.xlsx',
+ '.xml',
+ '.xz',
+ '.zip',
+)
+
+def extention_is_in_blacklist(url):
+ if url.split('?')[0].lower().endswith(extentions_blacklist):
+ return True
+ return False
+
+
+# Malformed urls.
+# This function is adapted from:
+# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
+url_regex = re.compile(
+ r'^(?:http)s?://' # http:// or https://
+ r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
+ r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
+ r'(?::\d+)?' # optional port
+ r'(?:/?|[/?]\S+)$', re.IGNORECASE)
+def url_is_malformed(url):
+ return re.match(url_regex, url) is None
+
+
+def print_progress(prefix, start_time, urls_counter,
+ domain_blacklist_counter,
+ extention_blacklist_counter,
+ short_url_counter, malformed_url_counter,
+ duplicate_url_counter):
+ string = prefix + ' | '
+ string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
+ string += 'number of urls: {} | '.format(urls_counter)
+ string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
+ string += 'extension blacklisted: {} | '.format(extention_blacklist_counter)
+ string += 'short urls (<=8): {} | '.format(short_url_counter)
+ string += 'malformed urls: {} | '.format(malformed_url_counter)
+ string += 'duplicate urls: {}'.format(duplicate_url_counter)
+ print(string, flush=True)
+
+
+if __name__ == '__main__':
+
+
+ print('remove blacklisted urls ..')
+
+ # Path to the url files.
+ path = sys.argv[1]
+ # Output url file.
+ output = sys.argv[2]
+
+ # Get the list of url files.
+ files = glob.glob(path + '/*.txt')
+ print('> found {} files'.format(len(files)))
+
+ urls = set()
+ urls_counter = 0
+ domain_blacklist_counter = 0
+ extention_blacklist_counter = 0
+ short_url_counter = 0
+ malformed_url_counter = 0
+ duplicate_url_counter = 0
+ start_time = time.time()
+ for filename in files:
+ with open(filename, 'r') as f:
+ for line in f:
+ url = line.strip()
+ urls_counter += 1
+ if domain_is_in_blacklist(url):
+ print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
+ domain_blacklist_counter += 1
+ elif extention_is_in_blacklist(url):
+ print('[EXTENSION BLACKLIST]: {}'.format(url), flush=True)
+ extention_blacklist_counter += 1
+ elif len(url) <= 8:
+ print('[SHORT URL]: {}'.format(url), flush=True)
+ short_url_counter += 1
+ elif url_is_malformed(url):
+ print('[MALFORMED URL]: {}'.format(url), flush=True)
+ malformed_url_counter += 1
+ elif url in urls:
+ print('[DUPLICATE URL]: {}'.format(url), flush=True)
+ duplicate_url_counter += 1
+ else:
+ urls.add(url)
+ if urls_counter % 100000 == 0:
+ print_progress('PROGRESS', start_time, urls_counter,
+ domain_blacklist_counter,
+ extention_blacklist_counter,
+ short_url_counter, malformed_url_counter,
+ duplicate_url_counter)
+
+ print_progress('FINAL', start_time, urls_counter,
+ domain_blacklist_counter,
+ extention_blacklist_counter,
+ short_url_counter, malformed_url_counter,
+ duplicate_url_counter)
+
+ # Write the final set of urls.
+ print('> writing cleaned up url list to {}'.format(output))
+ with open(output, 'w') as f:
+ for url in urls:
+ f.write(url + '\n')
+
+ print('done :-)')
diff --git a/megatron_lm/tools/openwebtext/cleanup_dataset.py b/megatron_lm/tools/openwebtext/cleanup_dataset.py
new file mode 100644
index 0000000..77dfbf7
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/cleanup_dataset.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import ftfy
+import json
+from langdetect import detect
+import numpy as np
+import time
+import os
+import sys
+
+from tokenizer import Tokenizer
+
+MIN_DOCUMENT_LENGTH = 128
+
+
+def print_progress(prefix, start_time, num_docs, num_fixed_text,
+ num_non_english_docs, chars_non_english_docs,
+ num_small_docs, chars_small_docs):
+
+ string = prefix + ' | '
+ string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
+ string += 'documents: {} | '.format(num_docs)
+ string += 'fixed text: {} | '.format(num_fixed_text)
+ string += 'non-english: {} | '.format(num_non_english_docs)
+ string += 'non-english chars: {} | '.format(chars_non_english_docs)
+ string += 'small docs: {} | '.format(num_small_docs)
+ string += 'small docs chars: {}'.format(chars_small_docs)
+ print(string, flush=True)
+
+
+def filter_corpus(filename, out_filename, print_interval=10000):
+
+ print(' > filtering {}'.format(filename))
+
+ tokenizer = Tokenizer(cache_dir='./cache')
+
+ num_docs = 0
+ num_written_docs = 0
+ num_small_docs = 0
+ num_fixed_text = 0
+ num_non_english_docs = 0
+ chars_non_english_docs = 0
+ chars_small_docs = 0
+ start_time = time.time()
+ with open(out_filename, 'wb') as f:
+ with open(filename, 'r') as fin:
+ for line in fin:
+ try:
+ num_docs += 1
+ myjson = json.loads(line)
+ # Fix text
+ text = ftfy.fix_text(myjson['text'])
+ if text != myjson['text']:
+ num_fixed_text += 1
+ myjson['text'] = text
+ # Detect language.
+ if detect(text) != 'en':
+ print('[non-english text]', myjson)
+ num_non_english_docs += 1
+ chars_non_english_docs += len(text)
+ continue
+ # On average each token is about 5 characters, so 8 characters
+ # per token is a conservative upper bound: only documents shorter
+ # than 8 * MIN_DOCUMENT_LENGTH can possibly be too short.
+ if len(text) < (8 * MIN_DOCUMENT_LENGTH):
+ tokens = tokenizer.tokenize_document(text)
+ if len(tokens) < MIN_DOCUMENT_LENGTH:
+ print('[small document, skipping]:', myjson)
+ num_small_docs += 1
+ chars_small_docs += len(text)
+ continue
+ myjson = json.dumps(myjson, ensure_ascii=False)
+ f.write(myjson.encode('utf-8'))
+ f.write('\n'.encode('utf-8'))
+ num_written_docs += 1
+ if num_docs % print_interval == 0:
+ print_progress('[PROGRESS]', start_time, num_docs,
+ num_fixed_text, num_non_english_docs,
+ chars_non_english_docs,
+ num_small_docs, chars_small_docs)
+ except Exception as e:
+ print(' skipping ', line, e)
+
+ print_progress('[FINAL]', start_time, num_docs,
+ num_fixed_text, num_non_english_docs,
+ chars_non_english_docs,
+ num_small_docs, chars_small_docs)
+
+
+if __name__ == '__main__':
+
+ print('building gpt2 dataset ...')
+
+ input_filename = sys.argv[1]
+ output_filename = sys.argv[2]
+
+ print('will be reading {}'.format(input_filename))
+ print('and will write the results to {}'.format(output_filename))
+
+ filter_corpus(input_filename, output_filename)
+
+
diff --git a/megatron_lm/tools/openwebtext/find_duplicates.py b/megatron_lm/tools/openwebtext/find_duplicates.py
new file mode 100644
index 0000000..455f43a
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/find_duplicates.py
@@ -0,0 +1,100 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import json
+from lsh import cache, minhash
+import time
+import sys
+
+
+# This function is adapted from:
+# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
+def shingles(text, char_ngram=5):
+ return set(text[head:head + char_ngram]
+ for head in range(0, len(text) - char_ngram))
+
+
+# This function is adapted from:
+# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
+def jaccard(set_a, set_b):
+ intersection = set_a & set_b
+ union = set_a | set_b
+ return len(intersection) / len(union)
+
+
+if __name__ == '__main__':
+
+ print('finding possible duplicate content ...')
+
+ input = sys.argv[1]
+ output = sys.argv[2]
+
+ hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
+ lshcache = cache.Cache(bands=10, hasher=hasher)
+
+ counter = 0
+ url_doc = {}
+ start_time = time.time()
+ with open(input, 'r') as f:
+ for line in f:
+ try:
+ myjson = json.loads(line)
+ url = myjson['url']
+ text = myjson['text']
+ counter += 1
+ url_doc[url] = text
+ lshcache.add_fingerprint(hasher.fingerprint(text), url)
+ except Exception as e:
+ print('Error:', e)
+ if counter % 10000 == 0:
+ print(' [read]> processed {} documents in {:.2f} seconds ...'.
+ format(counter, time.time() - start_time), flush=True)
+
+ counter = 0
+ start_time = time.time()
+ deduped = 0
+ with open(output, 'wb') as f:
+ for b in lshcache.bins:
+ for bucket_id in b:
+ if len(b[bucket_id]) > 1:
+ items = list(b[bucket_id])
+ main_url = items[0]
+ main_shingles = shingles(url_doc[main_url])
+ remove_urls = []
+ for i in range(1, len(items)):
+ counter += 1
+ other_url = items[i]
+ other_shingles = shingles(url_doc[other_url])
+ try:
+ jaccard_sim = jaccard(main_shingles, other_shingles)
+ except Exception as e:
+ print('Error:', e)
+ jaccard_sim = 0.0 # skip this pair if shingling or jaccard failed
+ if jaccard_sim > 0.5:
+ remove_urls.append({other_url: jaccard_sim})
+ deduped += 1
+ if counter % 10000 == 0:
+ print(' [write]> processed {} documents in {:.2f} '
+ 'seconds and deduped {} documents ...'.
+ format(counter, time.time() - start_time,
+ deduped), flush=True)
+ if len(remove_urls) > 0:
+ myjson = json.dumps({main_url: remove_urls},
+ ensure_ascii=False)
+ f.write(myjson.encode('utf-8'))
+ f.write('\n'.encode('utf-8'))
+
+ print('done :-)')
diff --git a/megatron_lm/tools/openwebtext/group_duplicates_url.py b/megatron_lm/tools/openwebtext/group_duplicates_url.py
new file mode 100644
index 0000000..3622ce3
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/group_duplicates_url.py
@@ -0,0 +1,90 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import time
+import sys
+
+
+def is_similar(jaccard_similarity):
+ return (jaccard_similarity >= 0.9)
+
+
+if __name__ == '__main__':
+
+
+ print('grouping duplicate urls ...')
+
+ input = sys.argv[1]
+ output = sys.argv[2]
+
+ url_to_index = {}
+ index_to_urls = []
+ counter = 0
+ start_time = time.time()
+ with open(input, 'r') as f:
+ for line in f:
+ counter += 1
+ myjson = json.loads(line)
+ urls = []
+ for main_url in myjson.keys():
+ urls.append(main_url)
+ for value in myjson[main_url]:
+ for other_url, js in value.items():
+ if is_similar(js):
+ urls.append(other_url)
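+ # Union the groups: reuse an existing group index if any of these
+ # urls already has one (start a new group otherwise), then fold any
+ # other groups that turned out to be linked into the current one.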
+ current_index = -1
+ other_indices = set()
+ for url in urls:
+ if url in url_to_index:
+ if current_index == -1:
+ current_index = url_to_index[url]
+ elif current_index != url_to_index[url]:
+ other_indices.add(url_to_index[url])
+ if current_index == -1:
+ current_index = len(index_to_urls)
+ index_to_urls.append(set())
+ for url in urls:
+ url_to_index[url] = current_index
+ index_to_urls[current_index].add(url)
+ for index in other_indices:
+ for url in index_to_urls[index]:
+ index_to_urls[current_index].add(url)
+ url_to_index[url] = current_index
+ index_to_urls[index] = None
+
+ if counter % 100000 == 0:
+ print(' > processed {} lines in {} seconds ...'.format(
+ counter, time.time() - start_time))
+
+
+ total_remove = 0
+ total_remain = 0
+ for urls in index_to_urls:
+ if urls is not None:
+ if len(urls) > 1:
+ total_remove += (len(urls) - 1)
+ total_remain += 1
+ print('out of {} urls, only {} are unique and {} should be removed'.format(
+ total_remove+total_remain, total_remain, total_remove))
+
+ with open(output, 'wb') as f:
+ for i, urls in enumerate(index_to_urls):
+ if urls is not None:
+ if len(urls) > 1:
+ myjson = json.dumps({str(i): list(urls)},
+ ensure_ascii=False)
+ f.write(myjson.encode('utf-8'))
+ f.write('\n'.encode('utf-8'))
diff --git a/megatron_lm/tools/openwebtext/merge_jsons.py b/megatron_lm/tools/openwebtext/merge_jsons.py
new file mode 100644
index 0000000..ad5ed31
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/merge_jsons.py
@@ -0,0 +1,55 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import glob
+import sys
+import json
+import argparse
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--json_path", type=str, default=".",
+ help="path where all the json files are located")
+
+ parser.add_argument("--output_file", type=str, default="merged_output.json",
+ help="filename where the merged json should go")
+
+ args = parser.parse_args()
+
+ json_path = args.json_path
+ out_file = args.output_file
+
+ json_files = glob.glob(json_path + '/*.json')
+
+ counter = 0
+
+ with open(out_file, 'w') as outfile:
+ for fname in json_files:
+ counter += 1
+
+ if counter % 1024 == 0:
+ print("Merging at ", counter, flush=True)
+
+ with open(fname, 'r') as infile:
+ for row in infile:
+ each_row = json.loads(row)
+ outfile.write(row)
+
+
+ print("Merged file", out_file, flush=True)
+
+
diff --git a/megatron_lm/tools/openwebtext/remove_group_duplicates.py b/megatron_lm/tools/openwebtext/remove_group_duplicates.py
new file mode 100644
index 0000000..b6f580e
--- /dev/null
+++ b/megatron_lm/tools/openwebtext/remove_group_duplicates.py
@@ -0,0 +1,69 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import time
+import sys
+
+
+if __name__ == '__main__':
+
+ url_filename = sys.argv[1]
+ data_filename = sys.argv[2]
+ output_filename = sys.argv[3]
+
+ urls = set()
+ with open(url_filename, 'r') as f:
+ for line in f:
+ myjson = json.loads(line)
+ for key in myjson:
+ this_urls = myjson[key]
+ for i in range(1, len(this_urls)):
+ urls.add(this_urls[i])
+ print('will be removing {} urls'.format(len(urls)), flush=True)
+
+ written_docs = 0
+ removed_docs = 0
+ removed_chars = 0
+ start_time = time.time()
+ with open(output_filename, 'wb') as fout:
+ with open(data_filename, 'r') as fin:
+ for line in fin:
+ try:
+ myjson = json.loads(line)
+ url = myjson['url']
+ if url in urls:
+ print('removing', myjson)
+ removed_docs += 1
+ removed_chars += len(myjson['text'])
+ continue
+ myjson = json.dumps(myjson, ensure_ascii=False)
+ fout.write(myjson.encode('utf-8'))
+ fout.write('\n'.encode('utf-8'))
+ written_docs += 1
+ if written_docs % 10000 == 0:
+ print(' [PROCESSED] time (s): {:.2f} | written: {} '
+ '| removed: {} (char: {})'.format(
+ time.time() - start_time,
+ written_docs, removed_docs, removed_chars))
+ except Exception as e:
+ print('[SKIPPING]', line, e)
+
+ print(' [PROCESSED] time (s): {:.2f} | written: {} '
+ '| removed: {} (char: {})'.format(
+ time.time() - start_time,
+ written_docs, removed_docs, removed_chars))
+ print('done :-)')
diff --git a/megatron_lm/tools/preprocess_data.py b/megatron_lm/tools/preprocess_data.py
new file mode 100644
index 0000000..c5f1392
--- /dev/null
+++ b/megatron_lm/tools/preprocess_data.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Processing data for pretraining."""
+
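+# A hypothetical invocation for GPT-2-style data (paths are placeholders):
+#   python tools/preprocess_data.py \
+#       --input train_data.json \
+#       --output-prefix my-gpt2 \
+#       --vocab-file gpt2-vocab.json \
+#       --merge-file gpt2-merges.txt \
+#       --tokenizer-type GPT2BPETokenizer \
+#       --append-eod --workers 8
+# With the default --json-keys (['text']) this writes
+# my-gpt2_text_document.bin and my-gpt2_text_document.idx.
+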
+import argparse
+import json
+import multiprocessing
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+ os.path.pardir)))
+import time
+
+import torch
+try:
+ import nltk
+ nltk_available = True
+except ImportError:
+ nltk_available = False
+
+from megatron.tokenizer import build_tokenizer
+from megatron.data import indexed_dataset
+
+
+# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
+class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
+
+ _period_context_fmt = r"""
+ \S* # some word material
+ %(SentEndChars)s # a potential sentence ending
+ \s* # <-- THIS is what I changed
+ (?=(?P<after_tok>
+ %(NonWord)s # either other punctuation
+ |
+ (?P<next_tok>\S+) # <-- Normally you would have \s+ here
+ ))"""
+
+class IdentitySplitter(object):
+ def tokenize(self, *text):
+ return text
+
+class Encoder(object):
+ def __init__(self, args):
+ self.args = args
+
+ def initializer(self):
+ # Use Encoder class as a container for global data
+ Encoder.tokenizer = build_tokenizer(self.args)
+ if self.args.split_sentences:
+ if not nltk_available:
+ print("NLTK is not available to split sentences.")
+ exit()
+ splitter = nltk.load("tokenizers/punkt/english.pickle")
+ if self.args.keep_newlines:
+ # this prevents punkt from eating newlines after sentences
+ Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
+ train_text = splitter._params,
+ lang_vars = CustomLanguageVars())
+ else:
+ Encoder.splitter = splitter
+
+ else:
+ Encoder.splitter = IdentitySplitter()
+
+ def encode(self, json_line):
+ data = json.loads(json_line)
+ ids = {}
+ for key in self.args.json_keys:
+ text = data[key]
+ doc_ids = []
+ for sentence in Encoder.splitter.tokenize(text):
+ sentence_ids = Encoder.tokenizer.tokenize(sentence)
+ if len(sentence_ids) > 0:
+ doc_ids.append(sentence_ids)
+ if self.args.append_eod:
+ doc_ids[-1].append(Encoder.tokenizer.eod)
+ ids[key] = doc_ids
+ return ids, len(json_line)
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ group = parser.add_argument_group(title='input data')
+ group.add_argument('--input', type=str, required=True,
+ help='Path to input JSON')
+ group.add_argument('--json-keys', nargs='+', default=['text'],
+ help='space separated list of keys to extract from json')
+ group.add_argument('--split-sentences', action='store_true',
+ help='Split documents into sentences.')
+ group.add_argument('--keep-newlines', action='store_true',
+ help='Keep newlines between sentences when splitting.')
+
+ group = parser.add_argument_group(title='tokenizer')
+ group.add_argument('--tokenizer-type', type=str, required=True,
+ choices=['BertWordPieceLowerCase','BertWordPieceCase',
+ 'GPT2BPETokenizer'],
+ help='What type of tokenizer to use.')
+ group.add_argument('--vocab-file', type=str, default=None,
+ help='Path to the vocab file')
+ group.add_argument('--merge-file', type=str, default=None,
+ help='Path to the BPE merge file (if necessary).')
+ group.add_argument('--append-eod', action='store_true',
+ help='Append an <eod> token to the end of a document.')
+
+
+ group = parser.add_argument_group(title='output data')
+ group.add_argument('--output-prefix', type=str, required=True,
+ help='Path to binary output file without suffix')
+ group.add_argument('--dataset-impl', type=str, default='mmap',
+ choices=['lazy', 'cached', 'mmap'])
+
+ group = parser.add_argument_group(title='runtime')
+ group.add_argument('--workers', type=int, default=1,
+ help='Number of worker processes to launch')
+ group.add_argument('--log-interval', type=int, default=100,
+ help='Interval between progress updates')
+ args = parser.parse_args()
+ args.keep_empty = False
+
+ if args.tokenizer_type.lower().startswith('bert'):
+ if not args.split_sentences:
+ print("Bert tokenizer detected, are you sure you don't want to split sentences?")
+
+ # some default/dummy values for the tokenizer
+ args.rank = 0
+ args.make_vocab_size_divisible_by = 128
+ args.model_parallel_size = 1
+
+ return args
+
+def main():
+ args = get_args()
+ startup_start = time.time()
+
+ print("Opening", args.input)
+ fin = open(args.input, 'r', encoding='utf-8')
+
+ if nltk_available and args.split_sentences:
+ nltk.download("punkt", quiet=True)
+
+ encoder = Encoder(args)
+ tokenizer = build_tokenizer(args)
+ pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
+ encoded_docs = pool.imap(encoder.encode, fin, 25)
+ #encoded_docs = map(encoder.encode, fin)
+
+ level = "document"
+ if args.split_sentences:
+ level = "sentence"
+
+ print(f"Vocab size: {tokenizer.vocab_size}")
+ print(f"Output prefix: {args.output_prefix}")
+ output_bin_files = {}
+ output_idx_files = {}
+ builders = {}
+ for key in args.json_keys:
+ output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
+ key, level)
+ output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
+ key, level)
+ builders[key] = indexed_dataset.make_builder(output_bin_files[key],
+ impl=args.dataset_impl,
+ vocab_size=tokenizer.vocab_size)
+
+ startup_end = time.time()
+ proc_start = time.time()
+ total_bytes_processed = 0
+ print("Time to startup:", startup_end - startup_start)
+
+ for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+ total_bytes_processed += bytes_processed
+ for key, sentences in doc.items():
+ for sentence in sentences:
+ builders[key].add_item(torch.IntTensor(sentence))
+ builders[key].end_document()
+ if i % args.log_interval == 0:
+ current = time.time()
+ elapsed = current - proc_start
+ mbs = total_bytes_processed/elapsed/1024/1024
+ print(f"Processed {i} documents",
+ f"({i/elapsed} docs/s, {mbs} MB/s).",
+ file=sys.stderr)
+
+ for key in args.json_keys:
+ builders[key].finalize(output_idx_files[key])
+
+if __name__ == '__main__':
+ main()