# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
[docs]class Biaffine(nn.Module):
r"""
Biaffine layer for first-order scoring :cite:`dozat-etal-2017-biaffine`.
This function has a tensor of weights :math:`W` and bias terms if needed.
The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y / d^s`,
where `d` and `s` are vector dimension and scaling factor respectively.
:math:`x` and :math:`y` can be concatenated with bias terms.
Args:
n_in (int):
The size of the input feature.
n_out (int):
The number of output channels.
scale (float):
Factor to scale the scores. Default: 0.
bias_x (bool):
If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``.
bias_y (bool):
If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
"""
def __init__(self, n_in, n_out=1, scale=0, bias_x=True, bias_y=True):
super().__init__()
self.n_in = n_in
self.n_out = n_out
self.scale = scale
self.bias_x = bias_x
self.bias_y = bias_y
self.weight = nn.Parameter(torch.Tensor(n_out, n_in+bias_x, n_in+bias_y))
self.reset_parameters()
def __repr__(self):
s = f"n_in={self.n_in}"
if self.n_out > 1:
s += f", n_out={self.n_out}"
if self.scale != 0:
s += f", scale={self.scale}"
if self.bias_x:
s += f", bias_x={self.bias_x}"
if self.bias_y:
s += f", bias_y={self.bias_y}"
return f"{self.__class__.__name__}({s})"
def reset_parameters(self):
nn.init.zeros_(self.weight)
[docs] def forward(self, x, y):
r"""
Args:
x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
Returns:
~torch.Tensor:
A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``.
If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
"""
if self.bias_x:
x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
if self.bias_y:
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
# [batch_size, n_out, seq_len, seq_len]
s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale
# remove dim 1 if n_out == 1
s = s.squeeze(1)
return s
[docs]class Triaffine(nn.Module):
r"""
Triaffine layer for second-order scoring :cite:`zhang-etal-2020-efficient,wang-etal-2019-second`.
This function has a tensor of weights :math:`W` and bias terms if needed.
The score :math:`s(x, y, z)` of the vector triple :math:`(x, y, z)` is computed as :math:`x^T z^T W y / d^s`,
where `d` and `s` are vector dimension and scaling factor respectively.
:math:`x` and :math:`y` can be concatenated with bias terms.
Args:
n_in (int):
The size of the input feature.
n_out (int):
The number of output channels.
scale (float):
Factor to scale the scores. Default: 0.
bias_x (bool):
If ``True``, adds a bias term for tensor :math:`x`. Default: ``False``.
bias_y (bool):
If ``True``, adds a bias term for tensor :math:`y`. Default: ``False``.
"""
def __init__(self, n_in, n_out=1, scale=0, bias_x=False, bias_y=False):
super().__init__()
self.n_in = n_in
self.n_out = n_out
self.scale = scale
self.bias_x = bias_x
self.bias_y = bias_y
self.weight = nn.Parameter(torch.Tensor(n_out, n_in+bias_x, n_in, n_in+bias_y))
self.reset_parameters()
def __repr__(self):
s = f"n_in={self.n_in}"
if self.n_out > 1:
s += f", n_out={self.n_out}"
if self.scale != 0:
s += f", scale={self.scale}"
if self.bias_x:
s += f", bias_x={self.bias_x}"
if self.bias_y:
s += f", bias_y={self.bias_y}"
return f"{self.__class__.__name__}({s})"
def reset_parameters(self):
nn.init.zeros_(self.weight)
[docs] def forward(self, x, y, z):
r"""
Args:
x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
z (torch.Tensor): ``[batch_size, seq_len, n_in]``.
Returns:
~torch.Tensor:
A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len, seq_len]``.
If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
"""
if self.bias_x:
x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
if self.bias_y:
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
w = torch.einsum('bzk,oikj->bozij', z, self.weight)
# [batch_size, n_out, seq_len, seq_len, seq_len]
s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale
# remove dim 1 if n_out == 1
s = s.squeeze(1)
return s