Commit 61d10100 authored by Myle Ott

Merge branch 'oss-master' into seq_task

parents fb4dda92 6d1233fa
......@@ -9,6 +9,9 @@ __pycache__/
# C extensions
*.so
# macOS dir files
.DS_Store
# Distribution / packaging
.Python
env/
......
......@@ -28,7 +28,7 @@ Fairseq features:
- Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
- sampling (unconstrained and top-k)
- large mini-batch training even on a single GPU via delayed updates
- fast half-precision floating point (FP16) training
- mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
We also provide [pre-trained models](#pre-trained-models-and-examples) for several benchmark
......@@ -39,7 +39,7 @@ translation and language modeling datasets.
# Requirements and Installation
* [PyTorch](http://pytorch.org/) version >= 1.0.0
* Python version >= 3.6
* Python version >= 3.5
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
Please follow the instructions here to install PyTorch: https://github.com/pytorch/pytorch#installation.
......
......@@ -74,12 +74,18 @@ Adding new models
.. autoclass:: fairseq.models.BaseFairseqModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqModel
.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqEncoderModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqLanguageModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqMultiModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqEncoder
:members:
.. autoclass:: fairseq.models.CompositeEncoder
......
......@@ -2,7 +2,7 @@ Modules
=======
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
.. automodule:: fairseq.modules
:members:
......
......@@ -41,7 +41,7 @@ New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::
@register_model('my_lstm')
class MyLSTM(FairseqModel):
class MyLSTM(FairseqEncoderDecoderModel):
(...)
Once registered, new plug-ins can be used with the existing :ref:`Command-line
......
......@@ -32,7 +32,7 @@ Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
command-line tool to create the dictionaries. While this tool is primarily
intended for sequence-to-sequence problems, we're able to reuse it here by
treating the label as a "target" sequence of length 1. We'll also output the
preprocessed files in "raw" format using the ``--output-format`` option to
preprocessed files in "raw" format using the ``--dataset-impl`` option to
enhance readability:
.. code-block:: console
......@@ -40,7 +40,7 @@ enhance readability:
> fairseq-preprocess \
--trainpref names/train --validpref names/valid --testpref names/test \
--source-lang input --target-lang label \
--destdir names-bin --output-format raw
--destdir names-bin --dataset-impl raw
After running the above command you should see a new directory,
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
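As a quick sanity check, here is a minimal sketch (assuming the dictionaries
follow the usual ``dict.<lang>.txt`` naming for the ``input`` and ``label``
languages chosen above) that loads them with :class:`~fairseq.data.Dictionary`:

.. code-block:: python

    from fairseq.data import Dictionary

    # Produced by the fairseq-preprocess command above
    # (--source-lang input --target-lang label --destdir names-bin).
    input_dict = Dictionary.load('names-bin/dict.input.txt')
    label_dict = Dictionary.load('names-bin/dict.label.txt')
    print(len(input_dict), len(label_dict))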
......
......@@ -2,9 +2,9 @@ Tutorial: Simple LSTM
=====================
In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqModel` that encodes a source sentence with an
LSTM and then passes the final hidden state to a second LSTM that decodes the
target sentence (without attention).
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).
This tutorial covers:
......@@ -233,18 +233,18 @@ Once the model is registered we'll be able to use it with the existing
All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqModel` interface.
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::
from fairseq.models import FairseqModel, register_model
from fairseq.models import FairseqEncoderDecoderModel, register_model
# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.
@register_model('simple_lstm')
class SimpleLSTMModel(FairseqModel):
class SimpleLSTMModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
......@@ -308,7 +308,7 @@ the name ``'simple_lstm'``::
# We could override the ``forward()`` if we wanted more control over how
# the encoder and decoder interact, but it's not necessary for this
# tutorial since we can inherit the default implementation provided by
# the FairseqModel base class, which looks like:
# the FairseqEncoderDecoderModel base class, which looks like:
#
# def forward(self, src_tokens, src_lengths, prev_output_tokens):
# encoder_out = self.encoder(src_tokens, src_lengths)
......
......@@ -6,7 +6,28 @@ This page includes pre-trained models from the paper [Understanding Back-Transla
Description | Dataset | Model | Test set(s)
---|---|---|---
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) | See NOTE in the archive
## Example usage
Interactive generation from the full ensemble via PyTorch Hub:
```
>>> import torch
>>> en2de_ensemble = torch.hub.load(
... 'pytorch/fairseq',
... 'transformer',
... model_name_or_path='transformer.wmt18.en-de',
... checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
... data_name_or_path='.',
... tokenizer='moses',
... aggressive_dash_splits=True,
... bpe='subword_nmt',
... )
>>> len(en2de_ensemble.models)
5
>>> print(en2de_ensemble.generate('Hello world!'))
Hallo Welt!
```
## Citation
```bibtex
......
......@@ -18,9 +18,10 @@ Let's assume the following for the code snippets in later sections to work
Pre-process and binarize the data with the MaskedLMDictionary and cross_lingual_lm task
```
```bash
# Ensure the output directory exists
mkdir -p monolingual_data/fairseq_processed
DATA_DIR=monolingual_data/fairseq_processed
mkdir -p "$DATA_DIR"
for lg in ar de en hi fr
do
......@@ -41,8 +42,8 @@ do
for stage in train test valid
sudo mv $stage.$lg-None.$lg.bin $stage.$lg.bin
sudo mv $stage.$lg-None.$lg.idx $stage.$lg.idx
sudo mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$stage.$lg.bin"
sudo mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$stage.$lg.idx"
done
......@@ -55,7 +56,7 @@ Use the following command to train the model on 5 languages.
```
fairseq-train \
--task cross_lingual_lm monolingual_data/processed \
--task cross_lingual_lm monolingual_data/fairseq_processed \
--save-dir checkpoints/mlm \
--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
--arch xlm_base \
......@@ -63,8 +64,8 @@ fairseq-train \
--lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
--dropout 0.1 \
--criterion masked_lm_loss \
--max-tokens 2048 --tokens-per-sample 256 --no-bias-kv --attention-dropout 0.1 \
--lazy-load --seed 0 \
--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
--dataset-impl lazy --seed 0 \
--masked-lm-only \
--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
--ddp-backend=no_c10d
......
......@@ -4,11 +4,37 @@
Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
## Example usage
Interactive generation via PyTorch Hub:
```
>>> import torch
>>> lm = torch.hub.load(
... 'pytorch/fairseq',
... 'transformer_lm',
... model_name_or_path='transformer_lm.wiki103.adaptive',
... data_name_or_path='./data-bin',
... tokenizer='moses',
... aggressive_dash_splits=True,
... no_escape=True,
... beam=1,
... sampling=True,
... sampling_topk=10,
... temperature=0.8,
... )
>>> lm.generate('Barack Obama', verbose=True)
```
Available models are listed in the ``hub_models()`` method in each model file, for example:
[transformer_lm.py](https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer_lm.py).
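As a rough sketch (assuming PyTorch >= 1.1, which added `torch.hub.list`), the hub entry points can also be enumerated programmatically:

```python
import torch

# Prints the callable entry points defined in fairseq's hubconf.py
# (e.g. 'transformer_lm'); the pre-trained checkpoint names themselves
# are the values accepted by model_name_or_path above.
print(torch.hub.list('pytorch/fairseq'))
```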
## Training a new model with the CLI tools
These scripts provide an example of pre-processing data for the Language Modeling task.
### prepare-wikitext-103.sh
......@@ -45,10 +71,8 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/transformer_wiki103/checkpoint_best.pt' \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024
```
Train a convolutional language model ([Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](conv_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
......@@ -63,5 +87,4 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/fconv_wiki103/checkpoint_best.pt'
```
......@@ -73,8 +73,7 @@ mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
--clip-norm 0.0 --weight-decay 0.0 \
--adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
--ddp-backend=no_c10d --max-tokens 3584 \
......@@ -99,8 +98,7 @@ mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
--clip-norm 0.0 --weight-decay 0.0 \
--adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
--ddp-backend=no_c10d --max-tokens 3584 \
......
......@@ -21,7 +21,6 @@ curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvz
and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token.
## Example usage
```
......
......@@ -9,9 +9,32 @@ Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) |
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) | See NOTE in the archive
## Example usage
## Example usage (torch.hub)
Interactive generation via PyTorch Hub:
```
>>> import torch
>>> en2de = torch.hub.load(
... 'pytorch/fairseq',
... 'transformer',
... model_name_or_path='transformer.wmt16.en-de',
... data_name_or_path='.',
... tokenizer='moses',
... aggressive_dash_splits=True,
... bpe='subword_nmt',
... )
>>> print(en2de.models[0].__class__)
<class 'fairseq.models.transformer.TransformerModel'>
>>> print(en2de.generate('Hello world!'))
Hallo Welt!
```
Available models are listed in the ``hub_models()`` method in each model file, for example:
[transformer.py](https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer.py).
## Example usage (CLI tools)
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
......
......@@ -38,12 +38,12 @@ Once a model is trained, we can generate translations from different experts usi
For example, to generate from expert 0:
```
$ fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
```
## Evaluate
......@@ -63,20 +63,20 @@ $ for EXPERT in $(seq 0 2); do \
fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--buffer 500 --max-tokens 6000 ; \
--buffer-size 500 --max-tokens 6000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT \
--gen-expert $EXPERT ; \
done > wmt14-en-de.extra_refs.tok.gen.3experts
```
Finally use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU:
Finally use `examples/translation_moe/score.py` to compute pairwise BLEU and multi-reference BLEU:
```
$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
$ python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
pairwise BLEU: 48.26
avg oracle BLEU: 49.50
#refs covered: 2.11
multi-reference BLEU (leave-one-out): 59.46
```
This matches row 3 from Table 7 in the paper.
......
......@@ -6,7 +6,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Scoring script for computing pairwise BLEU and oracle BLEU over a set of
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
......@@ -16,9 +16,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
import argparse
from itertools import chain
import sys
import numpy as np
import random
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
......@@ -37,6 +37,7 @@ def main():
print('pairwise BLEU: %.2f' % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
if args.ref:
_, _, refs = load_ref(args.ref)
if args.sys:
......@@ -154,19 +155,20 @@ def multi_ref(refs, hypos):
refs = list(zip(*refs))
hypos = list(zip(*hypos))
# compute average corpus BLEU
# compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
k = len(hypos)
m = len(refs)
concat_hypos = []
concat_refs = [[] for j in range(m - 1)]
for i in range(m):
concat_hypos.append([h for hs in hypos for h in hs])
rest = refs[:i] + refs[i+1:]
for j in range(m - 1):
concat_refs[j].extend(rest[j] * k)
concat_hypos = list(chain.from_iterable(concat_hypos))
bleu = corpus_bleu(concat_hypos, concat_refs)
print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
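    # Interleave the hypotheses so that the k system outputs for sentence i sit at
    # positions i*k .. i*k+k-1, and repeat every reference sentence k times so the
    # two lists stay aligned; each reference set is then held out in turn.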
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
duplicated_refs = [
[ref for ref in refs_i for _ in range(k)]
for refs_i in refs
]
loo_bleus = []
for held_out_ref in range(m):
remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
def intra_ref(refs):
......
......@@ -5,24 +5,138 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
from collections import OrderedDict
from typing import Union
import collections
import logging
import os
import re
import traceback
import shutil
import torch
from torch.serialization import default_restore_location
from fairseq import tasks
from fairseq.models import FairseqEncoder, FairseqDecoder
def load_checkpoint_to_cpu(path):
def save_checkpoint(args, trainer, epoch_itr, val_loss):
from fairseq import distributed_utils, meters
if args.no_save or not distributed_utils.is_master(args):
return
write_timer = meters.StopwatchMeter()
write_timer.start()
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
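    # Map each candidate checkpoint filename to whether it should be written on
    # this call; the first eligible file is saved once and the rest are copies.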
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
)
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
prev_best = getattr(save_checkpoint, 'best', val_loss)
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = {
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
}
if hasattr(save_checkpoint, 'best'):
extra_state.update({'best': save_checkpoint.best})
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
shutil.copyfile(checkpoints[0], cp)
write_timer.stop()
print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
checkpoints[0], epoch, updates, write_timer.sum))
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_interval_updates:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r'checkpoint(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer):
"""Load a checkpoint and restore the training iterator."""
# only one worker should attempt to create the required dir
if args.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
if os.path.isabs(args.restore_file):
checkpoint_path = args.restore_file
else:
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
extra_state = trainer.load_checkpoint(
checkpoint_path,
args.reset_optimizer,
args.reset_lr_scheduler,
eval(args.optimizer_overrides),
reset_meters=args.reset_meters,
)
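    # Carry over the best validation loss seen so far, unless the user asked to
    # reset the optimizer state or the meters.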
if (
extra_state is not None
and 'best' in extra_state
and not args.reset_optimizer
and not args.reset_meters
):
save_checkpoint.best = extra_state['best']
if extra_state is not None and not args.reset_dataloader:
# restore iterator from checkpoint
itr_state = extra_state['train_iterator']
epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(epoch=0)
trainer.lr_step(epoch_itr.epoch)
return extra_state, epoch_itr
def load_checkpoint_to_cpu(path, arg_overrides=None):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
state = torch.load(
path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
)
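    # default_restore_location(..., 'cpu') maps every tensor onto CPU,
    # regardless of the device it was saved from.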
args = state['args']
if arg_overrides is not None:
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
state = _upgrade_state_dict(state)
return state
......@@ -36,17 +150,20 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
were used during model training
task (fairseq.tasks.FairseqTask, optional): task to use for loading
"""
ensemble, args, _task = _load_model_ensemble(filenames, arg_overrides, task)
return ensemble, args
def _load_model_ensemble(filenames, arg_overrides=None, task=None):
from fairseq import tasks
ensemble = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
state = load_checkpoint_to_cpu(filename)
state = load_checkpoint_to_cpu(filename, arg_overrides)
args = state['args']
if arg_overrides is not None:
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
if task is None:
task = tasks.setup_task(args)
......@@ -54,8 +171,7 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
model = task.build_model(args)
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
return ensemble, args
return ensemble, args, task
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
......@@ -127,6 +243,8 @@ def save_state(
def _upgrade_state_dict(state):
"""Helper for upgrading old model checkpoints."""
from fairseq import models, registry, tasks
# add optimizer_history