deep_question_answering

Implementation of "Teaching Machines to Read and Comprehend", as proposed by Google DeepMind
git clone https://esimon.eu/repos/deep_question_answering.git

attentive_reader.py (6956B)


import theano
from theano import tensor
import numpy

from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
from blocks.bricks.lookup import LookupTable
from blocks.bricks.recurrent import LSTM

from blocks.filter import VariableFilter
from blocks.roles import WEIGHT
from blocks.graph import ComputationGraph, apply_dropout, apply_noise

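# Builds a stack of bidirectional LSTM layers on top of `seq` (time x batch x
# features). Layer k has `sizes[k]` units per direction; with skip=True each
# layer also receives the original input sequence as an extra input.
# Returns the created bricks and the list of forward/backward hidden states.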
def make_bidir_lstm_stack(seq, seq_dim, mask, sizes, skip=True, name=''):
    bricks = []

    curr_dim = [seq_dim]
    curr_hidden = [seq]

    hidden_list = []
    for k, dim in enumerate(sizes):
        fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_fwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
        fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_fwd_lstm_%d'%(name,k))

        bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_bwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
        bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_bwd_lstm_%d'%(name,k))

        bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins

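        # Each direction's LSTM input is the sum of linear transforms of all
        # current inputs (the 4*dim outputs pack the LSTM gate pre-activations).
        # The backward pass runs on the time-reversed sequence and mask.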
        fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
        bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
        fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=mask)
        bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=mask[::-1])
        hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
        if skip:
            curr_hidden = [seq, fwd_hidden, bwd_hidden[::-1]]
            curr_dim = [seq_dim, dim, dim]
        else:
            curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
            curr_dim = [dim, dim]

    return bricks, hidden_list

class Model():
    def __init__(self, config, vocab_size):
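        # Symbolic inputs: batches of word-id matrices (with 0/1 masks for
        # padding), the answer entity id, and the candidate entity ids.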
        question = tensor.imatrix('question')
        question_mask = tensor.imatrix('question_mask')
        context = tensor.imatrix('context')
        context_mask = tensor.imatrix('context_mask')
        answer = tensor.ivector('answer')
        candidates = tensor.imatrix('candidates')
        candidates_mask = tensor.imatrix('candidates_mask')

        bricks = []

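        # Blocks recurrent bricks expect time-major input, so transpose the
        # (batch, time) matrices to (time, batch).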
        question = question.dimshuffle(1, 0)
        question_mask = question_mask.dimshuffle(1, 0)
        context = context.dimshuffle(1, 0)
        context_mask = context_mask.dimshuffle(1, 0)

        # Embed questions and context
        embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
        bricks.append(embed)

        qembed = embed.apply(question)
        cembed = embed.apply(context)

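        # Encode the question and the context with two separate stacks of
        # bidirectional LSTMs.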
        qlstms, qhidden_list = make_bidir_lstm_stack(qembed, config.embed_size, question_mask.astype(theano.config.floatX),
                                                     config.question_lstm_size, config.question_skip_connections, 'q')
        clstms, chidden_list = make_bidir_lstm_stack(cembed, config.embed_size, context_mask.astype(theano.config.floatX),
                                                     config.ctx_lstm_size, config.ctx_skip_connections, 'ctx')
        bricks = bricks + qlstms + clstms

        # Calculate question encoding: concatenate the last hidden states of the
        # bidirectional LSTMs (all layers with skip connections, otherwise only the top layer)
        if config.question_skip_connections:
            qenc_dim = 2*sum(config.question_lstm_size)
            qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list], axis=1)
        else:
            qenc_dim = 2*config.question_lstm_size[-1]
            qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list[-2:]], axis=1)
        qenc.name = 'qenc'

        # Calculate context encoding: concatenate the bidirectional LSTM states at every context position
        if config.ctx_skip_connections:
            cenc_dim = 2*sum(config.ctx_lstm_size)
            cenc = tensor.concatenate(chidden_list, axis=2)
        else:
            cenc_dim = 2*config.ctx_lstm_size[-1]
            cenc = tensor.concatenate(chidden_list[-2:], axis=2)
        cenc.name = 'cenc'

        # Attention mechanism MLP
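        # Each context position is scored against the question: a tanh layer
        # combines the projected context states and the projected question
        # encoding (broadcast over time), then the MLP maps the result to one
        # scalar score per position.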
        attention_mlp = MLP(dims=config.attention_mlp_hidden + [1],
                            activations=config.attention_mlp_activations[1:] + [Identity()],
                            name='attention_mlp')
        attention_qlinear = Linear(input_dim=qenc_dim, output_dim=config.attention_mlp_hidden[0], name='attq')
        attention_clinear = Linear(input_dim=cenc_dim, output_dim=config.attention_mlp_hidden[0], use_bias=False, name='attc')
        bricks += [attention_mlp, attention_qlinear, attention_clinear]
        layer1 = Tanh().apply(attention_clinear.apply(cenc.reshape((cenc.shape[0]*cenc.shape[1], cenc.shape[2])))
                                        .reshape((cenc.shape[0],cenc.shape[1],config.attention_mlp_hidden[0]))
                             + attention_qlinear.apply(qenc)[None, :, :])
        layer1.name = 'layer1'
        att_weights = attention_mlp.apply(layer1.reshape((layer1.shape[0]*layer1.shape[1], layer1.shape[2])))
        att_weights.name = 'att_weights_0'
        att_weights = att_weights.reshape((layer1.shape[0], layer1.shape[1]))
        att_weights.name = 'att_weights'

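        # Normalise the scores over context positions with a softmax and take
        # the attention-weighted sum of the context encodings.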
        attended = tensor.sum(cenc * tensor.nnet.softmax(att_weights.T).T[:, :, None], axis=0)
        attended.name = 'attended'

        # Now we can calculate our output
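        # The output MLP takes [attended context; question encoding] and produces
        # one unnormalised score per entity (the softmax is applied in the cost).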
        out_mlp = MLP(dims=[cenc_dim + qenc_dim] + config.out_mlp_hidden + [config.n_entities],
                      activations=config.out_mlp_activations + [Identity()],
                      name='out_mlp')
        bricks += [out_mlp]
        probs = out_mlp.apply(tensor.concatenate([attended, qenc], axis=1))
        probs.name = 'probs'

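        # Restrict predictions to the candidate entities: entities that do not
        # appear in the (masked) candidate list get a large negative score.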
        is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
                                 tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
        probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))

        # Calculate prediction, cost and error rate
        pred = probs.argmax(axis=1)
        cost = Softmax().categorical_cross_entropy(answer, probs).mean()
        error_rate = tensor.neq(answer, pred).mean()

        # Apply regularisation: weight noise and dropout on the LSTM hidden states
        cg = ComputationGraph([cost, error_rate])
        if config.w_noise > 0:
            noise_vars = VariableFilter(roles=[WEIGHT])(cg)
            cg = apply_noise(cg, noise_vars, config.w_noise)
        if config.dropout > 0:
            cg = apply_dropout(cg, qhidden_list + chidden_list, config.dropout)
        [cost_reg, error_rate_reg] = cg.outputs

        # Name the cost and error-rate variables and expose them for training and monitoring
        cost_reg.name = cost.name = 'cost'
        error_rate_reg.name = error_rate.name = 'error_rate'

        self.sgd_cost = cost_reg
        self.monitor_vars = [[cost_reg], [error_rate_reg]]
        self.monitor_vars_valid = [[cost], [error_rate]]

        # Initialize bricks
        for brick in bricks:
            brick.weights_init = config.weights_init
            brick.biases_init = config.biases_init
            brick.initialize()

#  vim: set sts=4 ts=4 sw=4 tw=0 et :