Commit 70b02caf authored by alexeib

address issue #867 (missing elmo-dropout)

parent 693cc3ca
@@ -47,7 +47,7 @@ ${MT_DATA} is a processed machine translation dataset, e.g. WMT En2De.
 ```
 $ python train.py ${MT_DATA} -a transformer_wmt_en_de_big --no-enc-token-positional-embeddings \
-    --elmo-affine --elmo-softmax --clip-norm 0 --fp16 --optimizer adam --lr 0.0007 \
+    --elmo-affine --clip-norm 0 --fp16 --optimizer adam --lr 0.0007 \
     --label-smoothing 0.1 --ddp-backend no_c10d --dropout 0.3 --elmo-dropout 0.2 \
     --distributed-port 12597 --distributed-world-size 128 --max-tokens 3584 --no-progress-bar \
     --log-interval 100 --seed 1 --min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0 \
@@ -112,6 +112,9 @@ class TransformerModel(FairseqModel):
                             help='if set, adds bos to input')
         parser.add_argument('--elmo-affine', default=False, action='store_true',
                             help='if set, uses affine layer norm for elmo')
+        parser.add_argument('--elmo-dropout', type=float, metavar='D',
+                            help='dropout probability for elmo hidden layers')
         parser.add_argument('--decoder-embed-scale', type=float,
                             help='scaling factor for embeddings used in decoder')
         parser.add_argument('--encoder-embed-scale', type=float,
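Outside this diff, the new option behaves like any other optional argparse float. A standalone sketch (plain argparse rather than fairseq's parser setup, so the surrounding wiring here is illustrative only):

```
# Standalone illustration of the new flag; in fairseq it is registered inside
# TransformerModel.add_args() rather than on a bare ArgumentParser like this.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--elmo-dropout', type=float, metavar='D',
                    help='dropout probability for elmo hidden layers')

print(parser.parse_args(['--elmo-dropout', '0.2']).elmo_dropout)  # 0.2
print(parser.parse_args([]).elmo_dropout)  # None (no default= on the parser itself)
```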
@@ -143,7 +146,8 @@ class TransformerModel(FairseqModel):
             embedder = ElmoTokenEmbedder(models[0], dictionary.eos(), dictionary.pad(), add_bos=is_encoder,
                                          remove_bos=is_encoder, combine_tower_states=is_encoder,
                                          projection_dim=embed_dim, add_final_predictive=is_encoder,
-                                         add_final_context=is_encoder, affine_layer_norm=args.elmo_affine)
+                                         add_final_context=is_encoder, affine_layer_norm=args.elmo_affine,
+                                         weights_dropout=args.elmo_dropout)
             return embedder
         elif path.startswith('bilm:'):
             lm_path = path[5:]
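This hunk only forwards `weights_dropout` into `ElmoTokenEmbedder`; the embedder's internals are not part of the commit. As a rough mental model, the sketch below shows one common way dropout can be applied to the layer-mixing weights of an ELMo-style scalar mix. The class `ScalarMixSketch` and its details are assumptions for illustration, not fairseq's actual ElmoTokenEmbedder code.

```
# Hedged sketch: ELMo-style weighted sum of layer states, with dropout applied
# to the softmax-normalised mixing weights. Assumed illustration only; this is
# not ElmoTokenEmbedder and may differ from what weights_dropout really does.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScalarMixSketch(nn.Module):
    def __init__(self, num_layers, weights_dropout=0.0):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(num_layers))  # one scalar per layer
        self.gamma = nn.Parameter(torch.ones(1))               # global output scale
        self.weights_dropout = weights_dropout

    def forward(self, layer_states):
        # layer_states: list of tensors, each of shape (batch, seq_len, dim)
        w = F.softmax(self.weights, dim=0)
        # During training, randomly zero some layer weights so the mix does not
        # over-rely on any single layer.
        w = F.dropout(w, p=self.weights_dropout, training=self.training)
        mixed = sum(wi * s for wi, s in zip(w, layer_states))
        return self.gamma * mixed


# Tiny usage example with three fake layer states.
states = [torch.randn(2, 5, 8) for _ in range(3)]
mix = ScalarMixSketch(num_layers=3, weights_dropout=0.2)
out = mix(states)  # shape (2, 5, 8)
```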
@@ -1022,6 +1026,8 @@ def base_architecture(args):
     args.decoder_embed_scale = getattr(args, 'decoder_embed_scale', None)
     args.encoder_embed_scale = getattr(args, 'encoder_embed_scale', None)
+    args.elmo_dropout = getattr(args, 'elmo_dropout', 0.)
     args.bilm_mask_last_state = getattr(args, 'bilm_mask_last_state', False)
     args.bilm_add_bos = getattr(args, 'bilm_add_bos', False)
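The `getattr` fallback follows the usual `base_architecture` pattern: it supplies a value only when the attribute is missing from the args namespace altogether, for example when args come from an older run that predates the flag. A minimal illustration (the `Namespace` objects below are stand-ins, not fairseq's real args):

```
# Illustration of the getattr-default pattern used in base_architecture().
from argparse import Namespace


def apply_defaults(args):
    args.elmo_dropout = getattr(args, 'elmo_dropout', 0.)
    return args


new_args = Namespace(elmo_dropout=0.2)   # flag was set explicitly
old_args = Namespace()                   # older args without the attribute

print(apply_defaults(new_args).elmo_dropout)  # 0.2 (kept)
print(apply_defaults(old_args).elmo_dropout)  # 0.0 (fallback)
```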