Source code for supar.structs.fn

# -*- coding: utf-8 -*-

import torch
from supar.utils.common import MIN
from supar.utils.fn import pad
from torch.autograd import Function


[docs]def tarjan(sequence): r""" Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph. Args: sequence (list): List of head indices. Yields: A list of indices making up a SCC. All self-loops are ignored. Examples: >>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle [2, 5, 1] """ sequence = [-1] + sequence # record the search order, i.e., the timestep dfn = [-1] * len(sequence) # record the the smallest timestep in a SCC low = [-1] * len(sequence) # push the visited into the stack stack, onstack = [], [False] * len(sequence) def connect(i, timestep): dfn[i] = low[i] = timestep[0] timestep[0] += 1 stack.append(i) onstack[i] = True for j, head in enumerate(sequence): if head != i: continue if dfn[j] == -1: yield from connect(j, timestep) low[i] = min(low[i], low[j]) elif onstack[j]: low[i] = min(low[i], dfn[j]) # a SCC is completed if low[i] == dfn[i]: cycle = [stack.pop()] while cycle[-1] != i: onstack[cycle[-1]] = False cycle.append(stack.pop()) onstack[i] = False # ignore the self-loop if len(cycle) > 1: yield cycle timestep = [0] for i in range(len(sequence)): if dfn[i] == -1: yield from connect(i, timestep)
[docs]def chuliu_edmonds(s): r""" ChuLiu/Edmonds algorithm for non-projective decoding :cite:`mcdonald-etal-2005-non`. Some code is borrowed from `tdozat's implementation`_. Descriptions of notations and formulas can be found in :cite:`mcdonald-etal-2005-non`. Notes: The algorithm does not guarantee to parse a single-root tree. Args: s (~torch.Tensor): ``[seq_len, seq_len]``. Scores of all dependent-head pairs. Returns: ~torch.Tensor: A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree. .. _tdozat's implementation: https://github.com/tdozat/Parser-v3 """ s[0, 1:] = MIN # prevent self-loops s.diagonal()[1:].fill_(MIN) # select heads with highest scores tree = s.argmax(-1) # return the cycle finded by tarjan algorithm lazily cycle = next(tarjan(tree.tolist()[1:]), None) # if the tree has no cycles, then it is a MST if not cycle: return tree # indices of cycle in the original tree cycle = torch.tensor(cycle) # indices of noncycle in the original tree noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0) noncycle = torch.where(noncycle.gt(0))[0] def contract(s): # heads of cycle in original tree cycle_heads = tree[cycle] # scores of cycle in original tree s_cycle = s[cycle, cycle_heads] # calculate the scores of cycle's potential dependents # s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle s_dep = s[noncycle][:, cycle] # find the best cycle head for each noncycle dependent deps = s_dep.argmax(1) # calculate the scores of cycle's potential heads # s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle # a(v) is the predecessor of v in cycle # s(cycle) = sum(s(a(v)->v)) s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum() # find the best noncycle head for each cycle dependent heads = s_head.argmax(0) contracted = torch.cat((noncycle, torch.tensor([-1]))) # calculate the scores of contracted graph s = s[contracted][:, contracted] # set the contracted graph scores of cycle's potential dependents s[:-1, -1] = s_dep[range(len(deps)), deps] # set the contracted graph scores of cycle's potential heads s[-1, :-1] = s_head[heads, range(len(heads))] return s, heads, deps # keep track of the endpoints of the edges into and out of cycle for reconstruction later s, heads, deps = contract(s) # y is the contracted tree y = chuliu_edmonds(s) # exclude head of cycle from y y, cycle_head = y[:-1], y[-1] # fix the subtree with no heads coming from the cycle # len(y) denotes heads coming from the cycle subtree = y < len(y) # add the nodes to the new tree tree[noncycle[subtree]] = noncycle[y[subtree]] # fix the subtree with heads coming from the cycle subtree = ~subtree # add the nodes to the tree tree[noncycle[subtree]] = cycle[deps[subtree]] # fix the root of the cycle cycle_root = heads[cycle_head] # break the cycle and add the root of the cycle to the tree tree[cycle[cycle_root]] = noncycle[cycle_head] return tree
[docs]def mst(scores, mask, multiroot=False): r""" MST algorithm for decoding non-projective trees. This is a wrapper for ChuLiu/Edmonds algorithm. The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots, If ``multiroot=True`` and there indeed exist multi-roots, the algorithm seeks to find best single-root trees by iterating all possible single-root trees parsed by ChuLiu/Edmonds. Otherwise the resulting trees are directly taken as the final outputs. Args: scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. Scores of all dependent-head pairs. mask (~torch.BoolTensor): ``[batch_size, seq_len]``. The mask to avoid parsing over padding tokens. The first column serving as pseudo words for roots should be ``False``. multiroot (bool): Ensures to parse a single-root tree If ``False``. Returns: ~torch.Tensor: A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees. Examples: >>> scores = torch.tensor([[[-11.9436, -13.1464, -6.4789, -13.8917], [-60.6957, -60.2866, -48.6457, -63.8125], [-38.1747, -49.9296, -45.2733, -49.5571], [-19.7504, -23.9066, -9.9139, -16.2088]]]) >>> scores[:, 0, 1:] = MIN >>> scores.diagonal(0, 1, 2)[1:].fill_(MIN) >>> mask = torch.tensor([[False, True, True, True]]) >>> mst(scores, mask) tensor([[0, 2, 0, 2]]) """ batch_size, seq_len, _ = scores.shape scores = scores.cpu().unbind() preds = [] for i, length in enumerate(mask.sum(1).tolist()): s = scores[i][:length+1, :length+1] tree = chuliu_edmonds(s) roots = torch.where(tree[1:].eq(0))[0] + 1 if not multiroot and len(roots) > 1: s_root = s[:, 0] s_best = MIN s = s.index_fill(1, torch.tensor(0), MIN) for root in roots: s[:, 0] = MIN s[root, 0] = s_root[root] t = chuliu_edmonds(s) s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum() if s_tree > s_best: s_best, tree = s_tree, t preds.append(tree) return pad(preds, total_length=seq_len).to(mask.device)
class SampledLogsumexp(Function): @staticmethod def forward(ctx, x, dim=-1): ctx.dim = dim ctx.save_for_backward(x) return x.logsumexp(dim=dim) @staticmethod def backward(ctx, grad_output): from torch.distributions import OneHotCategorical x, dim = ctx.saved_tensors, ctx.dim if ctx.needs_input_grad[0]: return grad_output.unsqueeze(dim).mul(OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(-1, dim)), None return None, None class SparsemaxFunction(Function): @staticmethod def forward(ctx, x, dim=-1): ctx.dim = dim sorted_x, _ = x.sort(dim, True) z = sorted_x.cumsum(dim) - 1 k = x.new_tensor(range(1, sorted_x.size(dim) + 1)).view(-1, *[1] * (x.dim() - 1)).transpose(0, dim) k = (k * sorted_x).gt(z).sum(dim, True) tau = z.gather(dim, k - 1) / k p = torch.clamp(x - tau, 0) ctx.save_for_backward(k, p) return p @staticmethod def backward(ctx, grad_output): k, p, dim = *ctx.saved_tensors, ctx.dim grad = grad_output.masked_fill(p.eq(0), 0) grad = torch.where(p.ne(0), grad - grad.sum(dim, True) / k, grad) return grad, None sampled_logsumexp = SampledLogsumexp.apply sparsemax = SparsemaxFunction.apply