# -*- coding: utf-8 -*-
from typing import List, Tuple, Union
import torch
from supar.utils.common import MIN
from supar.utils.fn import pad
from torch.autograd import Function
[docs]def tarjan(sequence: List[int]) -> List[int]:
r"""
Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph.
Args:
sequence (list):
List of head indices.
Yields:
A list of indices making up a SCC. All self-loops are ignored.
Examples:
>>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle
[2, 5, 1]
"""
sequence = [-1] + sequence
# record the search order, i.e., the timestep
dfn = [-1] * len(sequence)
# record the the smallest timestep in a SCC
low = [-1] * len(sequence)
# push the visited into the stack
stack, onstack = [], [False] * len(sequence)
def connect(i, timestep):
dfn[i] = low[i] = timestep[0]
timestep[0] += 1
stack.append(i)
onstack[i] = True
for j, head in enumerate(sequence):
if head != i:
continue
if dfn[j] == -1:
yield from connect(j, timestep)
low[i] = min(low[i], low[j])
elif onstack[j]:
low[i] = min(low[i], dfn[j])
# a SCC is completed
if low[i] == dfn[i]:
cycle = [stack.pop()]
while cycle[-1] != i:
onstack[cycle[-1]] = False
cycle.append(stack.pop())
onstack[i] = False
# ignore the self-loop
if len(cycle) > 1:
yield cycle
timestep = [0]
for i in range(len(sequence)):
if dfn[i] == -1:
yield from connect(i, timestep)
[docs]def chuliu_edmonds(s: torch.Tensor) -> torch.Tensor:
r"""
ChuLiu/Edmonds algorithm for non-projective decoding :cite:`mcdonald-etal-2005-non`.
Some code is borrowed from `tdozat's implementation`_.
Descriptions of notations and formulas can be found in :cite:`mcdonald-etal-2005-non`.
Notes:
The algorithm does not guarantee to parse a single-root tree.
Args:
s (~torch.Tensor): ``[seq_len, seq_len]``.
Scores of all dependent-head pairs.
Returns:
~torch.Tensor:
A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree.
.. _tdozat's implementation:
https://github.com/tdozat/Parser-v3
"""
s[0, 1:] = MIN
# prevent self-loops
s.diagonal()[1:].fill_(MIN)
# select heads with highest scores
tree = s.argmax(-1)
# return the cycle finded by tarjan algorithm lazily
cycle = next(tarjan(tree.tolist()[1:]), None)
# if the tree has no cycles, then it is a MST
if not cycle:
return tree
# indices of cycle in the original tree
cycle = torch.tensor(cycle)
# indices of noncycle in the original tree
noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0)
noncycle = torch.where(noncycle.gt(0))[0]
def contract(s):
# heads of cycle in original tree
cycle_heads = tree[cycle]
# scores of cycle in original tree
s_cycle = s[cycle, cycle_heads]
# calculate the scores of cycle's potential dependents
# s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle
s_dep = s[noncycle][:, cycle]
# find the best cycle head for each noncycle dependent
deps = s_dep.argmax(1)
# calculate the scores of cycle's potential heads
# s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle
# a(v) is the predecessor of v in cycle
# s(cycle) = sum(s(a(v)->v))
s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum()
# find the best noncycle head for each cycle dependent
heads = s_head.argmax(0)
contracted = torch.cat((noncycle, torch.tensor([-1])))
# calculate the scores of contracted graph
s = s[contracted][:, contracted]
# set the contracted graph scores of cycle's potential dependents
s[:-1, -1] = s_dep[range(len(deps)), deps]
# set the contracted graph scores of cycle's potential heads
s[-1, :-1] = s_head[heads, range(len(heads))]
return s, heads, deps
# keep track of the endpoints of the edges into and out of cycle for reconstruction later
s, heads, deps = contract(s)
# y is the contracted tree
y = chuliu_edmonds(s)
# exclude head of cycle from y
y, cycle_head = y[:-1], y[-1]
# fix the subtree with no heads coming from the cycle
# len(y) denotes heads coming from the cycle
subtree = y < len(y)
# add the nodes to the new tree
tree[noncycle[subtree]] = noncycle[y[subtree]]
# fix the subtree with heads coming from the cycle
subtree = ~subtree
# add the nodes to the tree
tree[noncycle[subtree]] = cycle[deps[subtree]]
# fix the root of the cycle
cycle_root = heads[cycle_head]
# break the cycle and add the root of the cycle to the tree
tree[cycle[cycle_root]] = noncycle[cycle_head]
return tree
[docs]def mst(scores: torch.Tensor, mask: torch.BoolTensor, multiroot: bool = False) -> torch.Tensor:
r"""
MST algorithm for decoding non-projective trees.
This is a wrapper for ChuLiu/Edmonds algorithm.
The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots,
If ``multiroot=True`` and there indeed exist multi-roots, the algorithm seeks to find
best single-root trees by iterating all possible single-root trees parsed by ChuLiu/Edmonds.
Otherwise the resulting trees are directly taken as the final outputs.
Args:
scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
Scores of all dependent-head pairs.
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
The mask to avoid parsing over padding tokens.
The first column serving as pseudo words for roots should be ``False``.
multiroot (bool):
Ensures to parse a single-root tree If ``False``.
Returns:
~torch.Tensor:
A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees.
Examples:
>>> scores = torch.tensor([[[-11.9436, -13.1464, -6.4789, -13.8917],
[-60.6957, -60.2866, -48.6457, -63.8125],
[-38.1747, -49.9296, -45.2733, -49.5571],
[-19.7504, -23.9066, -9.9139, -16.2088]]])
>>> scores[:, 0, 1:] = MIN
>>> scores.diagonal(0, 1, 2)[1:].fill_(MIN)
>>> mask = torch.tensor([[False, True, True, True]])
>>> mst(scores, mask)
tensor([[0, 2, 0, 2]])
"""
_, seq_len, _ = scores.shape
scores = scores.cpu().unbind()
preds = []
for i, length in enumerate(mask.sum(1).tolist()):
s = scores[i][:length+1, :length+1]
tree = chuliu_edmonds(s)
roots = torch.where(tree[1:].eq(0))[0] + 1
if not multiroot and len(roots) > 1:
s_root = s[:, 0]
s_best = MIN
s = s.index_fill(1, torch.tensor(0), MIN)
for root in roots:
s[:, 0] = MIN
s[root, 0] = s_root[root]
t = chuliu_edmonds(s)
s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum()
if s_tree > s_best:
s_best, tree = s_tree, t
preds.append(tree)
return pad(preds, total_length=seq_len).to(mask.device)
class Logsumexp(Function):
r"""
Safer ``logsumexp`` to cure unnecessary NaN values that arise from inf arguments.
See discussions at http://github.com/pytorch/pytorch/issues/49724.
To be optimized with C++/Cuda extensions.
"""
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
output = x.logsumexp(dim)
ctx.dim = dim
ctx.save_for_backward(x, output)
return output.clone()
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
x, output, dim = *ctx.saved_tensors, ctx.dim
g, output = g.unsqueeze(dim), output.unsqueeze(dim)
mask = g.eq(0).expand_as(x)
grad = g * (x - output).exp()
return torch.where(mask, x.new_tensor(0.), grad), None
class Logaddexp(Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.logaddexp(x, y)
ctx.save_for_backward(x, y, output)
return output.clone()
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]:
x, y, output = ctx.saved_tensors
mask = g.eq(0)
grad_x, grad_y = (x - output).exp(), (y - output).exp()
grad_x = torch.where(mask, x.new_tensor(0.), grad_x)
grad_y = torch.where(mask, y.new_tensor(0.), grad_y)
return grad_x, grad_y
class SampledLogsumexp(Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
ctx.dim = dim
ctx.save_for_backward(x)
return x.logsumexp(dim=dim)
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
from torch.distributions import OneHotCategorical
x, dim = ctx.saved_tensors, ctx.dim
return g.unsqueeze(dim).mul(OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(-1, dim)), None
class Sparsemax(Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
ctx.dim = dim
sorted_x, _ = x.sort(dim, True)
z = sorted_x.cumsum(dim) - 1
k = x.new_tensor(range(1, sorted_x.size(dim) + 1)).view(-1, *[1] * (x.dim() - 1)).transpose(0, dim)
k = (k * sorted_x).gt(z).sum(dim, True)
tau = z.gather(dim, k - 1) / k
p = torch.clamp(x - tau, 0)
ctx.save_for_backward(k, p)
return p
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, g: torch.Tensor) -> Tuple[torch.Tensor, None]:
k, p, dim = *ctx.saved_tensors, ctx.dim
grad = g.masked_fill(p.eq(0), 0)
grad = torch.where(p.ne(0), grad - grad.sum(dim, True) / k, grad)
return grad, None
logsumexp = Logsumexp.apply
logaddexp = Logaddexp.apply
sampled_logsumexp = SampledLogsumexp.apply
sparsemax = Sparsemax.apply