# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
from supar.models import (BiaffineDependencyModel, CRF2oDependencyModel,
CRFDependencyModel, VIDependencyModel)
from supar.parsers.parser import Parser
from supar.structs import Dependency2oCRF, DependencyCRF, MatrixTree
from supar.utils import Config, Dataset, Embedding
from supar.utils.common import BOS, PAD, UNK
from supar.utils.field import ChartField, Field, RawField, SubwordField
from supar.utils.fn import ispunct
from supar.utils.logging import get_logger, progress_bar
from supar.utils.metric import AttachmentMetric
from supar.utils.transform import CoNLL
# Module-level logger used by the parsers below for progress and build messages.
logger = get_logger(__name__)
class BiaffineDependencyParser(Parser):
    r"""
    The implementation of Biaffine Dependency Parser :cite:`dozat-etal-2017-biaffine`.
    """

    NAME = 'biaffine-dependency'
    MODEL = BiaffineDependencyModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # shortcuts to the fields registered on the CoNLL transform
        self.TAG = self.transform.CPOS
        self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL

    def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
              punct=False, tree=False, proj=False, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            train/dev/test (list[list] or str):
                Filenames of the train/dev/test datasets.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 32.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            update_steps (int):
                Gradient accumulation steps. Default: 1.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs.
        """

        return super().train(**Config().update(locals()))

    def evaluate(self, data, buckets=8, batch_size=5000,
                 punct=False, tree=True, proj=False, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for evaluation, both list of instances and filename are allowed.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating evaluation configs.

        Returns:
            The loss scalar and evaluation results.
        """

        return super().evaluate(**Config().update(locals()))

    def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=False,
                tree=True, proj=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for prediction, both a list of instances and filename are allowed.
            pred (str):
                If specified, the predicted results will be saved to the file. Default: ``None``.
            lang (str):
                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
                ``None`` if tokenization is not required.
                Default: ``None``.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            prob (bool):
                If ``True``, outputs the probabilities. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating prediction configs.

        Returns:
            A :class:`~supar.utils.Dataset` object that stores the predicted results.
        """

        return super().predict(**Config().update(locals()))

    @classmethod
    def load(cls, path, reload=False, src='github', **kwargs):
        r"""
        Loads a parser with data fields and pretrained model parameters.

        Args:
            path (str):
                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
                  to load from cache or download, e.g., ``'biaffine-dep-en'``.
                - a local path to a pretrained model, e.g., ``./<path>/model``.
            reload (bool):
                Whether to discard the existing cache and force a fresh download. Default: ``False``.
            src (str):
                Specifies where to download the model.
                ``'github'``: github release page.
                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
                Default: ``'github'``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs and initializing the model.

        Examples:
            >>> from supar import Parser
            >>> parser = Parser.load('biaffine-dep-en')
            >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
        """

        return super().load(path, reload, src, **kwargs)

    def _train(self, loader):
        r"""
        Runs one training epoch over ``loader``.

        Losses are divided by ``update_steps`` and gradients accumulated; gradient clipping,
        the optimizer step and the scheduler step happen once per accumulation cycle.
        """

        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()

        for i, batch in enumerate(bar, 1):
            words, texts, *feats, arcs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            # for 3-D inputs (e.g. subword pieces) a position counts if any piece is non-pad
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # clip the fully accumulated gradient rather than each partial sum
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            # ignore all punctuation if not specified.
            # This must run BEFORE the partial-annotation filter: the flattened flags hold one
            # entry per word and are scattered onto the True positions of `mask` in order, so
            # removing positions first would misalign them with the words.
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            if self.args.partial:
                mask &= arcs.ge(0)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
        logger.info(f"{bar.postfix}")

    @torch.no_grad()
    def _evaluate(self, loader):
        r"""
        Evaluates the model on ``loader``.

        Returns:
            The loss averaged over batches and the accumulated :class:`AttachmentMetric`.
        """

        self.model.eval()

        total_loss, metric = 0, AttachmentMetric()

        for batch in loader:
            words, texts, *feats, arcs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            # ignore all punctuation if not specified; done before the partial filter so the
            # per-word punctuation flags line up with the word positions of `mask`
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            if self.args.partial:
                mask &= arcs.ge(0)
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss /= len(loader)

        return total_loss, metric

    @torch.no_grad()
    def _predict(self, loader):
        r"""
        Predicts heads and relations for every sentence in ``loader``.

        Returns:
            A dict with per-sentence ``'arcs'`` (lists of head indices), ``'rels'`` (relation labels)
            and, if ``self.args.prob`` is set, ``'probs'`` (arc probability matrices; else ``None``).
        """

        self.model.eval()

        preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
        for batch in progress_bar(loader):
            words, texts, *feats = batch
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            s_arc, s_rel = self.model(words, feats)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            preds['arcs'].extend(arc_preds[mask].split(lens))
            preds['rels'].extend(rel_preds[mask].split(lens))
            if self.args.prob:
                preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.softmax(-1).unbind())])
        preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
        preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]

        return preds

    @classmethod
    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
        r"""
        Build a brand-new Parser, including initialization of all data fields and model parameters.

        Args:
            path (str):
                The path of the model to be saved.
            min_freq (int):
                The minimum frequency needed to include a token in the vocabulary.
                Required if taking words as encoder input.
                Default: 2.
            fix_len (int):
                The max length of all subword pieces. The excess part of each piece will be truncated.
                Required if using CharLSTM/BERT.
                Default: 20.
            kwargs (dict):
                A dict holding the unconsumed arguments.
        """

        args = Config(**locals())
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
        # reuse the saved fields/configs when the model already exists and no rebuild is requested
        if os.path.exists(path) and not args.build:
            parser = cls.load(**args)
            parser.model = cls.MODEL(**parser.args)
            parser.model.load_pretrained(parser.WORD.embed).to(args.device)
            return parser

        logger.info("Building the fields")
        TAG, CHAR, ELMO, BERT = None, None, None, None
        if args.encoder == 'bert':
            from transformers import (AutoTokenizer, GPT2Tokenizer,
                                      GPT2TokenizerFast)
            t = AutoTokenizer.from_pretrained(args.bert)
            WORD = SubwordField('words',
                                pad=t.pad_token,
                                unk=t.unk_token,
                                bos=t.bos_token or t.cls_token,
                                fix_len=args.fix_len,
                                tokenize=t.tokenize,
                                fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' '+x)
            WORD.vocab = t.get_vocab()
        else:
            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
            if 'tag' in args.feat:
                TAG = Field('tags', bos=BOS)
            if 'char' in args.feat:
                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len)
            if 'elmo' in args.feat:
                from allennlp.modules.elmo import batch_to_ids
                ELMO = RawField('elmo')
                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
            if 'bert' in args.feat:
                from transformers import (AutoTokenizer, GPT2Tokenizer,
                                          GPT2TokenizerFast)
                t = AutoTokenizer.from_pretrained(args.bert)
                BERT = SubwordField('bert',
                                    pad=t.pad_token,
                                    unk=t.unk_token,
                                    bos=t.bos_token or t.cls_token,
                                    fix_len=args.fix_len,
                                    tokenize=t.tokenize,
                                    fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' '+x)
                BERT.vocab = t.get_vocab()
        TEXT = RawField('texts')
        ARC = Field('arcs', bos=BOS, use_vocab=False, fn=CoNLL.get_arcs)
        REL = Field('rels', bos=BOS)
        transform = CoNLL(FORM=(WORD, TEXT, CHAR, ELMO, BERT), CPOS=TAG, HEAD=ARC, DEPREL=REL)

        train = Dataset(transform, args.train)
        if args.encoder != 'bert':
            WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
            if TAG is not None:
                TAG.build(train)
            if CHAR is not None:
                CHAR.build(train)
        REL.build(train)
        args.update({
            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
            'n_rels': len(REL.vocab),
            'n_tags': len(TAG.vocab) if TAG is not None else None,
            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
            'bert_pad_index': BERT.pad_index if BERT is not None else None,
            'pad_index': WORD.pad_index,
            'unk_index': WORD.unk_index,
            'bos_index': WORD.bos_index
        })
        logger.info(f"{transform}")

        logger.info("Building the model")
        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None).to(args.device)
        logger.info(f"{model}\n")

        return cls(args, model, transform)
class CRFDependencyParser(BiaffineDependencyParser):
    r"""
    The implementation of first-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`.
    """

    NAME = 'crf-dependency'
    MODEL = CRFDependencyModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
              punct=False, mbr=True, tree=False, proj=False, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            train/dev/test (list[list] or str):
                Filenames of the train/dev/test datasets.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 32.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            update_steps (int):
                Gradient accumulation steps. Default: 1.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs.
        """

        return super().train(**Config().update(locals()))

    def evaluate(self, data, buckets=8, batch_size=5000, punct=False,
                 mbr=True, tree=True, proj=True, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for evaluation, both list of instances and filename are allowed.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``True``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating evaluation configs.

        Returns:
            The loss scalar and evaluation results.
        """

        return super().evaluate(**Config().update(locals()))

    def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=False,
                mbr=True, tree=True, proj=True, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for prediction, both a list of instances and filename are allowed.
            pred (str):
                If specified, the predicted results will be saved to the file. Default: ``None``.
            lang (str):
                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
                ``None`` if tokenization is not required.
                Default: ``None``.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            prob (bool):
                If ``True``, outputs the probabilities. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``True``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating prediction configs.

        Returns:
            A :class:`~supar.utils.Dataset` object that stores the predicted results.
        """

        return super().predict(**Config().update(locals()))

    @classmethod
    def load(cls, path, reload=False, src='github', **kwargs):
        r"""
        Loads a parser with data fields and pretrained model parameters.

        Args:
            path (str):
                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
                  to load from cache or download, e.g., ``'crf-dep-en'``.
                - a local path to a pretrained model, e.g., ``./<path>/model``.
            reload (bool):
                Whether to discard the existing cache and force a fresh download. Default: ``False``.
            src (str):
                Specifies where to download the model.
                ``'github'``: github release page.
                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
                Default: ``'github'``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs and initializing the model.

        Examples:
            >>> from supar import Parser
            >>> parser = Parser.load('crf-dep-en')
            >>> parser = Parser.load('./ptb.crf.dep.lstm.char')
        """

        return super().load(path, reload, src, **kwargs)

    def _train(self, loader):
        r"""
        Runs one training epoch over ``loader``.

        Losses are divided by ``update_steps`` and gradients accumulated; gradient clipping,
        the optimizer step and the scheduler step happen once per accumulation cycle.
        """

        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()

        for i, batch in enumerate(bar, 1):
            words, texts, *feats, arcs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            # for 3-D inputs (e.g. subword pieces) a position counts if any piece is non-pad
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            # the model's loss also returns the arc scores to decode from (marginals when mbr is set)
            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # clip the fully accumulated gradient rather than each partial sum
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            # ignore all punctuation if not specified.
            # This must run BEFORE the partial-annotation filter: the flattened flags hold one
            # entry per word and are scattered onto the True positions of `mask` in order, so
            # removing positions first would misalign them with the words.
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            if self.args.partial:
                mask &= arcs.ge(0)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
        logger.info(f"{bar.postfix}")

    @torch.no_grad()
    def _evaluate(self, loader):
        r"""
        Evaluates the model on ``loader``.

        Returns:
            The loss averaged over batches and the accumulated :class:`AttachmentMetric`.
        """

        self.model.eval()

        total_loss, metric = 0, AttachmentMetric()

        for batch in loader:
            words, texts, *feats, arcs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            # ignore all punctuation if not specified; done before the partial filter so the
            # per-word punctuation flags line up with the word positions of `mask`
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            if self.args.partial:
                mask &= arcs.ge(0)
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss /= len(loader)

        return total_loss, metric

    @torch.no_grad()
    def _predict(self, loader):
        r"""
        Predicts heads and relations for every sentence in ``loader``.

        With ``mbr`` set, decoding uses CRF marginals (projective ``DependencyCRF`` when ``proj``
        is set, otherwise ``MatrixTree``), and those marginals are reported as probabilities.

        Returns:
            A dict with per-sentence ``'arcs'``, ``'rels'`` and, if ``self.args.prob`` is set,
            ``'probs'`` (otherwise ``None``).
        """

        self.model.eval()

        CRF = DependencyCRF if self.args.proj else MatrixTree
        preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
        for batch in progress_bar(loader):
            words, _, *feats = batch
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1)
            s_arc, s_rel = self.model(words, feats)
            s_arc = CRF(s_arc, lens).marginals if self.args.mbr else s_arc
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            lens = lens.tolist()
            preds['arcs'].extend(arc_preds[mask].split(lens))
            preds['rels'].extend(rel_preds[mask].split(lens))
            if self.args.prob:
                # marginals are already probabilities; raw scores still need a softmax
                arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
                preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())])
        preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
        preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]

        return preds
class CRF2oDependencyParser(BiaffineDependencyParser):
    r"""
    The implementation of second-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`.
    """

    NAME = 'crf2o-dependency'
    MODEL = CRF2oDependencyModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
              punct=False, mbr=True, tree=False, proj=False, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            train/dev/test (list[list] or str):
                Filenames of the train/dev/test datasets.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 32.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            update_steps (int):
                Gradient accumulation steps. Default: 1.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs.
        """

        return super().train(**Config().update(locals()))

    def evaluate(self, data, buckets=8, batch_size=5000, punct=False,
                 mbr=True, tree=True, proj=True, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for evaluation, both list of instances and filename are allowed.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``True``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating evaluation configs.

        Returns:
            The loss scalar and evaluation results.
        """

        return super().evaluate(**Config().update(locals()))

    def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=False,
                mbr=True, tree=True, proj=True, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for prediction, both a list of instances and filename are allowed.
            pred (str):
                If specified, the predicted results will be saved to the file. Default: ``None``.
            lang (str):
                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
                ``None`` if tokenization is not required.
                Default: ``None``.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            prob (bool):
                If ``True``, outputs the probabilities. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``True``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating prediction configs.

        Returns:
            A :class:`~supar.utils.Dataset` object that stores the predicted results.
        """

        return super().predict(**Config().update(locals()))

    @classmethod
    def load(cls, path, reload=False, src='github', **kwargs):
        r"""
        Loads a parser with data fields and pretrained model parameters.

        Args:
            path (str):
                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
                  to load from cache or download, e.g., ``'crf2o-dep-en'``.
                - a local path to a pretrained model, e.g., ``./<path>/model``.
            reload (bool):
                Whether to discard the existing cache and force a fresh download. Default: ``False``.
            src (str):
                Specifies where to download the model.
                ``'github'``: github release page.
                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
                Default: ``'github'``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs and initializing the model.

        Examples:
            >>> from supar import Parser
            >>> parser = Parser.load('crf2o-dep-en')
            >>> parser = Parser.load('./ptb.crf2o.dep.lstm.char')
        """

        return super().load(path, reload, src, **kwargs)

    def _train(self, loader):
        r"""
        Runs one training epoch over ``loader``.

        Losses are divided by ``update_steps`` and gradients accumulated; gradient clipping,
        the optimizer step and the scheduler step happen once per accumulation cycle.
        """

        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()

        for i, batch in enumerate(bar, 1):
            words, texts, *feats, arcs, sibs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            # for 3-D inputs (e.g. subword pieces) a position counts if any piece is non-pad
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            # the model's loss also returns the arc scores to decode from (marginals when mbr is set)
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # clip the fully accumulated gradient rather than each partial sum
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask)
            # ignore all punctuation if not specified.
            # This must run BEFORE the partial-annotation filter: the flattened flags hold one
            # entry per word and are scattered onto the True positions of `mask` in order, so
            # removing positions first would misalign them with the words.
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            if self.args.partial:
                mask &= arcs.ge(0)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
        logger.info(f"{bar.postfix}")

    @torch.no_grad()
    def _evaluate(self, loader):
        r"""
        Evaluates the model on ``loader``.

        Returns:
            The loss averaged over batches and the accumulated :class:`AttachmentMetric`.
        """

        self.model.eval()

        total_loss, metric = 0, AttachmentMetric()

        for batch in loader:
            words, texts, *feats, arcs, sibs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
            # ignore all punctuation if not specified; done before the partial filter so the
            # per-word punctuation flags line up with the word positions of `mask`
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            if self.args.partial:
                mask &= arcs.ge(0)
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss /= len(loader)

        return total_loss, metric

    @torch.no_grad()
    def _predict(self, loader):
        r"""
        Predicts heads and relations for every sentence in ``loader``.

        With ``mbr`` set, arc scores are replaced by second-order CRF marginals
        (:class:`Dependency2oCRF` over arc and sibling scores), which are then reported as probabilities.

        Returns:
            A dict with per-sentence ``'arcs'``, ``'rels'`` and, if ``self.args.prob`` is set,
            ``'probs'`` (otherwise ``None``).
        """

        self.model.eval()

        preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
        for batch in progress_bar(loader):
            words, texts, *feats = batch
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1)
            s_arc, s_sib, s_rel = self.model(words, feats)
            s_arc = Dependency2oCRF((s_arc, s_sib), lens).marginals if self.args.mbr else s_arc
            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
            lens = lens.tolist()
            preds['arcs'].extend(arc_preds[mask].split(lens))
            preds['rels'].extend(rel_preds[mask].split(lens))
            if self.args.prob:
                # marginals are already probabilities; raw scores still need a softmax
                arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
                preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())])
        preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
        preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]

        return preds

    @classmethod
    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
        r"""
        Build a brand-new Parser, including initialization of all data fields and model parameters.

        Unlike the first-order parsers, the ``HEAD`` column also feeds a ``sibs`` chart field
        holding the sibling structures used by the second-order model.

        Args:
            path (str):
                The path of the model to be saved.
            min_freq (int):
                The minimum frequency needed to include a token in the vocabulary. Default: 2.
            fix_len (int):
                The max length of all subword pieces. The excess part of each piece will be truncated.
                Required if using CharLSTM/BERT.
                Default: 20.
            kwargs (dict):
                A dict holding the unconsumed arguments.
        """

        args = Config(**locals())
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
        # reuse the saved fields/configs when the model already exists and no rebuild is requested
        if os.path.exists(path) and not args.build:
            parser = cls.load(**args)
            parser.model = cls.MODEL(**parser.args)
            parser.model.load_pretrained(parser.WORD.embed).to(args.device)
            return parser

        logger.info("Building the fields")
        TAG, CHAR, ELMO, BERT = None, None, None, None
        if args.encoder == 'bert':
            from transformers import (AutoTokenizer, GPT2Tokenizer,
                                      GPT2TokenizerFast)
            t = AutoTokenizer.from_pretrained(args.bert)
            WORD = SubwordField('words',
                                pad=t.pad_token,
                                unk=t.unk_token,
                                bos=t.bos_token or t.cls_token,
                                fix_len=args.fix_len,
                                tokenize=t.tokenize,
                                fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' '+x)
            WORD.vocab = t.get_vocab()
        else:
            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
            if 'tag' in args.feat:
                TAG = Field('tags', bos=BOS)
            if 'char' in args.feat:
                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len)
            if 'elmo' in args.feat:
                from allennlp.modules.elmo import batch_to_ids
                ELMO = RawField('elmo')
                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
            if 'bert' in args.feat:
                from transformers import (AutoTokenizer, GPT2Tokenizer,
                                          GPT2TokenizerFast)
                t = AutoTokenizer.from_pretrained(args.bert)
                BERT = SubwordField('bert',
                                    pad=t.pad_token,
                                    unk=t.unk_token,
                                    bos=t.bos_token or t.cls_token,
                                    fix_len=args.fix_len,
                                    tokenize=t.tokenize,
                                    fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' '+x)
                BERT.vocab = t.get_vocab()
        TEXT = RawField('texts')
        ARC = Field('arcs', bos=BOS, use_vocab=False, fn=CoNLL.get_arcs)
        SIB = ChartField('sibs', bos=BOS, use_vocab=False, fn=CoNLL.get_sibs)
        REL = Field('rels', bos=BOS)
        transform = CoNLL(FORM=(WORD, TEXT, CHAR, ELMO, BERT), CPOS=TAG, HEAD=(ARC, SIB), DEPREL=REL)

        train = Dataset(transform, args.train)
        if args.encoder != 'bert':
            WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
            if TAG is not None:
                TAG.build(train)
            if CHAR is not None:
                CHAR.build(train)
        REL.build(train)
        args.update({
            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
            'n_rels': len(REL.vocab),
            'n_tags': len(TAG.vocab) if TAG is not None else None,
            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
            'bert_pad_index': BERT.pad_index if BERT is not None else None,
            'pad_index': WORD.pad_index,
            'unk_index': WORD.unk_index,
            'bos_index': WORD.bos_index
        })
        logger.info(f"{transform}")

        logger.info("Building the model")
        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None).to(args.device)
        logger.info(f"{model}\n")

        return cls(args, model, transform)
class VIDependencyParser(BiaffineDependencyParser):
    r"""
    The implementation of Dependency Parser using Variational Inference :cite:`wang-tu-2020-second`.

    Reuses the data fields and public entry points of :class:`BiaffineDependencyParser`. Its model
    additionally produces sibling scores (``s_sib``), and the arc scores returned by ``model.loss`` /
    ``model.inference`` replace the raw arc scores before decoding.
    """

    NAME = 'vi-dependency'
    MODEL = VIDependencyModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
              punct=False, tree=False, proj=False, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            train/dev/test (list[list] or str):
                Filenames of the train/dev/test datasets.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 32.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            update_steps (int):
                Gradient accumulation steps. Default: 1.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs.
        """
        # NOTE: all arguments are forwarded through ``locals()``, so no extra locals may be defined here.
        return super().train(**Config().update(locals()))

    def evaluate(self, data, buckets=8, batch_size=5000, punct=False,
                 tree=True, proj=True, partial=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for evaluation, both list of instances and filename are allowed.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            punct (bool):
                If ``False``, ignores the punctuation during evaluation. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``True``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating evaluation configs.

        Returns:
            The loss scalar and evaluation results.
        """
        # NOTE: all arguments are forwarded through ``locals()``, so no extra locals may be defined here.
        return super().evaluate(**Config().update(locals()))

    def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=False,
                tree=True, proj=True, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for prediction, both a list of instances and filename are allowed.
            pred (str):
                If specified, the predicted results will be saved to the file. Default: ``None``.
            lang (str):
                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
                ``None`` if tokenization is not required.
                Default: ``None``.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            prob (bool):
                If ``True``, outputs the probabilities. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``True``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating prediction configs.

        Returns:
            A :class:`~supar.utils.Dataset` object that stores the predicted results.
        """
        # NOTE: all arguments are forwarded through ``locals()``, so no extra locals may be defined here.
        return super().predict(**Config().update(locals()))

    @classmethod
    def load(cls, path, reload=False, src='github', **kwargs):
        r"""
        Loads a parser with data fields and pretrained model parameters.

        Args:
            path (str):
                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
                  to load from cache or download, e.g., ``'vi-dep-en'``.
                - a local path to a pretrained model, e.g., ``./<path>/model``.
            reload (bool):
                Whether to discard the existing cache and force a fresh download. Default: ``False``.
            src (str):
                Specifies where to download the model.
                ``'github'``: github release page.
                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
                Default: ``'github'``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs and initializing the model.

        Examples:
            >>> from supar import Parser
            >>> parser = Parser.load('vi-dep-en')
            >>> parser = Parser.load('./ptb.vi.dep.lstm.char')
        """
        return super().load(path, reload, src, **kwargs)

    def _train(self, loader):
        r"""
        Runs one training epoch over ``loader``.

        Gradients are accumulated over ``self.args.update_steps`` batches. The accumulated gradient is
        clipped to ``self.args.clip`` and applied only at the end of each accumulation window.
        Running attachment scores are shown on the progress bar and logged at the end.

        Args:
            loader: an iterable of batches unpacking to ``(words, texts, *feats, arcs, rels)``.
        """
        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()

        for i, batch in enumerate(bar, 1):
            words, texts, *feats, arcs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            # ``words`` may carry an extra trailing dim (e.g. subword pieces); collapse it to a token-level mask
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            # the loss also returns the arc scores used for decoding (presumably VI marginals)
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # Clip the fully accumulated gradient once per update. Clipping each partial sum
                # would distort accumulation when update_steps > 1.
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            # NOTE(review): gradients of a trailing partial window (len(loader) % update_steps batches)
            # are neither applied nor cleared here; confirm this is intended.

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            if self.args.partial:
                # only score tokens with an annotated head
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
        logger.info(f"{bar.postfix}")

    @torch.no_grad()
    def _evaluate(self, loader):
        r"""
        Computes the average loss and attachment scores over ``loader`` without updating parameters.

        Args:
            loader: an iterable of batches unpacking to ``(words, texts, *feats, arcs, rels)``.

        Returns:
            A tuple ``(total_loss, metric)``. ``total_loss`` is the loss summed over batches and divided
            by the number of batches; ``metric`` is the accumulated :class:`AttachmentMetric`.
        """
        self.model.eval()

        total_loss, metric = 0, AttachmentMetric()

        for batch in loader:
            words, texts, *feats, arcs, rels = batch
            word_mask = words.ne(self.args.pad_index)
            # ``words`` may carry an extra trailing dim (e.g. subword pieces); collapse it to a token-level mask
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            if self.args.partial:
                # only score tokens with an annotated head
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        # guard against an empty loader, which would otherwise raise ZeroDivisionError
        total_loss /= max(len(loader), 1)

        return total_loss, metric

    @torch.no_grad()
    def _predict(self, loader):
        r"""
        Predicts heads and relation labels for every sentence in ``loader``.

        Args:
            loader: an iterable of batches unpacking to ``(words, texts, *feats)``.

        Returns:
            A dict with keys:
              - ``'arcs'``: a list of head-index lists, one list per sentence.
              - ``'rels'``: a list of relation-label lists, mapped through ``self.REL.vocab``.
              - ``'probs'``: if ``self.args.prob`` is set, a list of CPU tensors cropped to
                ``[1:len+1, :len+1]`` per sentence; otherwise ``None``.
        """
        self.model.eval()

        preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
        for batch in progress_bar(loader):
            words, texts, *feats = batch
            word_mask = words.ne(self.args.pad_index)
            # ``words`` may carry an extra trailing dim (e.g. subword pieces); collapse it to a token-level mask
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            s_arc, s_sib, s_rel = self.model(words, feats)
            # joint arc/sibling inference; the result replaces the raw arc scores (presumably marginals)
            s_arc = self.model.inference((s_arc, s_sib), mask)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            preds['arcs'].extend(arc_preds[mask].split(lens))
            preds['rels'].extend(rel_preds[mask].split(lens))
            if self.args.prob:
                # drop the first row (first token) and crop padding for each sentence
                preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())])
        preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
        preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]

        return preds