Commit aaad0c9b authored by Torge Berckmann

Added LM for comparison

parent 616b6dc2
@@ -17,6 +17,7 @@ from fairseq.models.transformer.transformer_config import (
DEFAULT_MIN_PARAMS_TO_WRAP,
)
from .trans_xl_base import TransformerXLBase
from .mem_transformer import MemTransformerLM
class SimpleLSTMEncoder(FairseqEncoder):
@@ -300,54 +301,6 @@ def lengths_to_mask(lengths, max_length, flip=False):
return mask.to(lengths.device)
init_std = 0.1
proj_init_std = 0.01
def init_weight(weight):
nn.init.uniform_(weight, -init_std, init_std)
def init_bias(bias):
nn.init.constant_(bias, 0.0)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, init_std)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
init_bias(m.r_bias)
class ContextXLEncoder(FairseqEncoder):
def __init__(
self, args, dictionary, embed_tokens
@@ -355,28 +308,15 @@ class ContextXLEncoder(FairseqEncoder):
super().__init__(dictionary)
self.args = args
d_embed = args.encoder_embed_dim + 128
n_token = embed_tokens.num_embeddings
n_layer = args.encoder_layers
n_head = args.encoder_attention_heads
d_model = args.encoder_embed_dim
d_head = d_model // n_head
d_inner = args.encoder_ffn_embed_dim
dropout = args.dropout
dropatt = args.attention_dropout
tgt_len, mem_len, ext_len = 0, 0, 0
torch.set_printoptions(threshold=10_000)
print("intial_num_tokens", n_token)
self.encoder = MemTransformerLM( n_token, n_layer, n_head, d_model, d_head, d_inner,
dropout, dropatt, tie_weight=True, d_embed=d_embed,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1)
self.encoder.apply(weights_init)
self.encoder.word_emb.apply(weights_init)
self.encoder = TransformerXLBase(n_token, n_layer, n_head, d_model, d_inner, dropout, dropatt, as_encoder=True)
def forward(
self,
@@ -395,10 +335,8 @@ class ContextXLEncoder(FairseqEncoder):
bsz, seq_len = src_tokens.shape
# dec_inp is "qlen, bsz"
dec_inp = src_tokens.transpose(0,1)
core_out, new_mems = self.encoder._forward(dec_inp)
core_out, new_mems = self.encoder.fwd(src_tokens, src_lengths)
ones_mask = lengths_to_mask(src_lengths, seq_len, flip=self.args.left_pad_source)
bool_mask = ones_mask.bool().logical_not()
@@ -498,7 +498,7 @@ class MemTransformerLM(nn.Module):
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1):
sample_softmax=-1, as_encoder=False):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
@@ -507,6 +507,7 @@
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.as_encoder = as_encoder
self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
div_val=div_val)
@@ -659,7 +660,8 @@ class MemTransformerLM(nn.Module):
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
# In encoder there is no mask
dec_attn_mask = torch.zeros_like(dec_attn_mask)
if self.as_encoder:
dec_attn_mask = torch.zeros_like(dec_attn_mask)
hids = []
if self.attn_type == 0: # default
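The hunk above makes the attention mask conditional: by default MemTransformerLM keeps its causal language-model mask, while as_encoder=True zeroes the mask so every position can attend to every other position. A minimal sketch of the two mask variants (the shapes are illustrative, not taken from the commit):
import torch
qlen, mlen = 5, 0  # query length, memory length
klen = qlen + mlen
# Decoder / LM mode: block attention to future positions (strict upper triangle).
causal_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]
# Encoder mode (as_encoder=True): zeroed mask, i.e. full bidirectional attention.
encoder_mask = torch.zeros_like(causal_mask)
The new module below wraps MemTransformerLM behind this flag as TransformerXLBase (imported elsewhere in the commit as .trans_xl_base) and builds the comparison language model on top of it.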
import torch
from torch import nn
from fairseq.models import (
FairseqLanguageModel,
FairseqDecoder,
register_model,
register_model_architecture
)
from .mem_transformer import MemTransformerLM
init_std = 0.1
proj_init_std = 0.01
def init_weight(weight):
nn.init.uniform_(weight, -init_std, init_std)
def init_bias(bias):
nn.init.constant_(bias, 0.0)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, init_std)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
init_bias(m.r_bias)
class TransformerXLBase(nn.Module):
def __init__(self, n_token, n_layer, n_head, d_model, d_inner, dropout, dropatt, as_encoder=False):
super().__init__()
d_head = d_model // n_head
d_embed = d_model + 128
tgt_len, mem_len, ext_len = 0, 0, 0
self.txl = MemTransformerLM( n_token, n_layer, n_head, d_model, d_head, d_inner,
dropout, dropatt, tie_weight=True, d_embed=d_embed,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1, as_encoder=as_encoder)
self.txl.apply(weights_init)
self.txl.word_emb.apply(weights_init)
def fwd(self, src_tokens, src_lengths):
# dec_inp is "qlen, bsz"
dec_inp = src_tokens.transpose(0,1)
core_out, new_mems = self.txl._forward(dec_inp)
return core_out, new_mems
def Linear(in_features, out_features, bias=True, dropout=0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.uniform_(-0.1, 0.1)
if bias:
m.bias.data.uniform_(-0.1, 0.1)
return m
# Basis for language model
class ContextXLDecoder(FairseqDecoder):
def __init__(self, args, dictionary, num_embeds):
super().__init__(dictionary)
n_layer = args.decoder_layers
n_head = args.decoder_attention_heads
d_model = args.decoder_embed_dim
d_inner = args.decoder_ffn_embed_dim
dropout = args.dropout
dropatt = args.attention_dropout
self.txl = TransformerXLBase(num_embeds, n_layer, n_head, d_model, d_inner, dropout, dropatt, as_encoder=False)
self.fc_out = Linear(d_model, num_embeds)
def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
core_out, new_mems = self.txl.fwd(prev_output_tokens, kwargs["src_lengths"])
cout_t = core_out.transpose(0,1)
x = self.fc_out(cout_t)
return x, None
# Input (prev tokens): [b x seqlen]
# output: [b x seqlen x vocablen]
@register_model("contextxl_lm")
class ContextXLLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@staticmethod
def add_args(parser):
# Models can override this method to add new command-line arguments.
# Here we'll add some new command-line arguments to configure dropout
# and the dimensionality of the embeddings and hidden states.
parser.add_argument(
'--decoder-embed-dim', type=int, metavar='N',
help='dimensionality of the decoder embeddings',
)
parser.add_argument(
'--decoder-ffn-embed-dim', type=int, metavar='N',
help='dimensionality of the decoder feed-forward layer',
)
parser.add_argument(
'--dropout', type=float, default=0.1,
help='dropout probability',
)
parser.add_argument(
'--decoder-layers', type=int, metavar='N',
help='number of decoder layers',
)
parser.add_argument(
'--decoder-attention-heads', type=int, metavar='N',
help='number of decoder attention heads',
)
parser.add_argument(
'--attention-dropout', type=float, default=0.1,
help='attention dropout probability',
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
num_embeds = len(task.source_dictionary)
decoder = ContextXLDecoder(
args, task.target_dictionary, num_embeds
)
return cls(decoder)
@register_model_architecture('contextxl_lm', 'arch_contextxl_lm')
def base_architecture(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.dropout = getattr(args, "dropout", 0.1)