# coding:utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import numpy as np
import paddle
from ..data import Pad, Vocab
from .models import BiAffineParser
from .task import Task
from .utils import download_file
# Usage examples for the "dependency_parsing" Taskflow. DDParserTask stores this
# string on each instance as ``self._usage``; its text is a runtime value and is
# kept verbatim (including the Chinese example sentences and comment).
usage = r"""
from paddlenlp import Taskflow
ddp = Taskflow("dependency_parsing")
ddp("三亚是一座美丽的城市")
'''
[{'word': ['三亚', '是', '一座', '美丽', '的', '城市'], 'head': [2, 0, 6, 6, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'MT', 'VOB']}]
'''
ddp(["三亚是一座美丽的城市", "他送了一本书"])
'''
[{'word': ['三亚', '是', '一座', '美丽', '的', '城市'], 'head': [2, 0, 6, 6, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'MT', 'VOB']}, {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]
'''
ddp = Taskflow("dependency_parsing", prob=True, use_pos=True)
ddp("三亚是一座美丽的城市")
'''
[{'word': ['三亚', '是', '一座', '美丽的城市'], 'head': [2, 0, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'VOB'], 'postag': ['LOC', 'v', 'm', 'n'], 'prob': [1.0, 1.0, 1.0, 1.0]}]
'''
ddp = Taskflow("dependency_parsing", model="ddparser-ernie-1.0")
ddp("三亚是一座美丽的城市")
'''
[{'word': ['三亚', '是', '一座', '美丽', '的', '城市'], 'head': [2, 0, 6, 6, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'MT', 'VOB']}]
'''
ddp = Taskflow("dependency_parsing", model="ddparser-ernie-gram-zh")
ddp("三亚是一座美丽的城市")
'''
[{'word': ['三亚', '是', '一座', '美丽', '的', '城市'], 'head': [2, 0, 6, 6, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'MT', 'VOB']}]
'''
# 已分词输入
ddp = Taskflow("dependency_parsing", segmented=True)
ddp.from_segments([["三亚", "是", "一座", "美丽", "的", "城市"]])
'''
[{'word': ['三亚', '是', '一座', '美丽', '的', '城市'], 'head': [2, 0, 6, 6, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'MT', 'VOB']}]
'''
ddp.from_segments([['三亚', '是', '一座', '美丽', '的', '城市'], ['他', '送', '了', '一本', '书']])
'''
[{'word': ['三亚', '是', '一座', '美丽', '的', '城市'], 'head': [2, 0, 6, 6, 4, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'MT', 'VOB']}, {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]
'''
"""
[docs]
class DDParserTask(Task):
    """
    DDParser task to analyze the dependency relationship between words in a sentence.

    Args:
        task(string): The name of task.
        model(string): The model name in the task. One of "ddparser",
            "ddparser-ernie-1.0" or "ddparser-ernie-gram-zh".
        tree(bool): Ensure the output conforms to the tree structure.
        prob(bool): Whether to return the probability of predicted heads.
        use_pos(bool): Whether to return the postag.
        use_cuda(bool): Passed to LAC (the word segmenter / POS tagger) as ``use_cuda``.
        batch_size(int): Numbers of examples a batch.
        return_visual(bool): If True, the result will contain the dependency visualization.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
    """

    resource_files_names = {
        "model_state": "model_state.pdparams",
        "word_vocab": "word_vocab.json",
        "rel_vocab": "rel_vocab.json",
    }
    resource_files_urls = {
        "ddparser": {
            "model_state": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser/model_state.pdparams",
                "f388c91e85b5b4d0db40157a4ee28c08",
            ],
            "word_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser/word_vocab.json",
                "594694033b149cbb724cac0975df07e4",
            ],
            "rel_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser/rel_vocab.json",
                "0decf1363278705f885184ff8681f4cd",
            ],
        },
        "ddparser-ernie-1.0": {
            "model_state": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser-ernie-1.0/model_state.pdparams",
                "78a4d5c2add642a88f6fdbee3574f617",
            ],
            "word_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser-ernie-1.0/word_vocab.json",
                "17ed37b5b7ebb8475d4bff1ff8dac4b7",
            ],
            "rel_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser-ernie-1.0/rel_vocab.json",
                "0decf1363278705f885184ff8681f4cd",
            ],
        },
        "ddparser-ernie-gram-zh": {
            "model_state": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser-ernie-gram-zh/model_state.pdparams",
                "9d0a49026feb97fac22c8eec3e88f5c3",
            ],
            "word_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser-ernie-gram-zh/word_vocab.json",
                "38120123d39876337975cc616901c8b9",
            ],
            "rel_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/ddparser-ernie-gram-zh/rel_vocab.json",
                "0decf1363278705f885184ff8681f4cd",
            ],
        },
        "font_file": {
            "font_file": [
                "https://bj.bcebos.com/paddlenlp/taskflow/dependency_parsing/SourceHanSansCN-Regular.ttf",
                "cecb7328bc0b9412b897fb3fc61edcdb",
            ]
        },
    }

    def __init__(
        self,
        task,
        model,
        tree=True,
        prob=False,
        use_pos=False,
        use_cuda=False,
        batch_size=1,
        return_visual=False,
        **kwargs
    ):
        super().__init__(task=task, model=model, **kwargs)
        self._usage = usage
        self.model = model
        # Task model name -> encoder that BiAffineParser is built with.
        encoding_models = {
            "ddparser": "lstm-pe",
            "ddparser-ernie-1.0": "ernie-1.0",
            "ddparser-ernie-gram-zh": "ernie-gram-zh",
        }
        if self.model not in encoding_models:
            # Implicit concatenation instead of a backslash continuation, which would
            # embed the source line's indentation into the message.
            raise ValueError(
                "The encoding model should be one of "
                "ddparser, ddparser-ernie-1.0 and ddparser-ernie-gram-zh"
            )
        self.encoding_model = encoding_models[self.model]
        self._check_task_files()
        self._construct_vocabs()
        self.font_file_path = download_file(
            self._task_path,
            "SourceHanSansCN-Regular.ttf",
            self.resource_files_urls["font_file"]["font_file"][0],
            self.resource_files_urls["font_file"]["font_file"][1],
        )
        self.tree = tree
        self.prob = prob
        self.use_pos = use_pos
        self.batch_size = batch_size
        self.return_visual = return_visual
        try:
            from LAC import LAC
        except Exception as err:
            # Keep raising ImportError, but chain the real cause so failures other than
            # "not installed" are still diagnosable.
            raise ImportError("Please install the dependencies first, pip install LAC --upgrade") from err
        self.use_cuda = use_cuda
        # "lac" mode yields (words, tags) pairs, "seg" mode yields words only.
        self.lac = LAC(mode="lac" if self.use_pos else "seg", use_cuda=self.use_cuda)
        self._get_inference_model()

    def _check_segmented_words(self, inputs):
        """
        Unwrap and validate pre-segmented input.

        ``inputs[0]`` must be a list of sentences, each a non-empty list of non-empty words.
        The outer wrapper is presumably the positional-args tuple built by the Taskflow
        front end — TODO confirm against the caller.

        Raises:
            TypeError: if any sentence is not a non-empty list of truthy words.
        """
        inputs = inputs[0]
        if not all(isinstance(i, list) and i and all(i) for i in inputs):
            raise TypeError("Invalid input format.")
        return inputs

    def from_segments(self, segmented_words):
        """
        Parse sentences that are already segmented into words.

        Returns the same result dicts as a raw-text call, without "postag".
        """
        segmented_words = self._check_segmented_words(segmented_words)
        # POS tags are not available for segmented inputs. Disable them only for this
        # call. Clearing the flag permanently would make later raw-text calls on a task
        # whose LAC runs in "lac" mode treat (words, tags) pairs as words.
        use_pos = self.use_pos
        self.use_pos = False
        try:
            inputs = self._preprocess_words({"words": segmented_words})
            outputs = self._run_model(inputs)
            return self._postprocess(outputs)
        finally:
            self.use_pos = use_pos

    def _construct_input_spec(self):
        """
        Construct the input spec for the convert dygraph model to static model.
        """
        self._input_spec = [
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),
        ]

    def _construct_vocabs(self):
        """Load the word/relation vocabularies and cache the special word ids."""
        word_vocab_path = os.path.join(self._task_path, "word_vocab.json")
        rel_vocab_path = os.path.join(self._task_path, "rel_vocab.json")
        self.word_vocab = Vocab.from_json(word_vocab_path)
        self.rel_vocab = Vocab.from_json(rel_vocab_path)
        self.word_pad_index = self.word_vocab.to_indices("[PAD]")
        self.word_bos_index = self.word_vocab.to_indices("[CLS]")
        self.word_eos_index = self.word_vocab.to_indices("[SEP]")

    def _construct_model(self, model):
        """
        Construct the inference model for the predictor.
        """
        model_instance = BiAffineParser(
            encoding_model=self.encoding_model,
            n_rels=len(self.rel_vocab),
            n_words=len(self.word_vocab),
            pad_index=self.word_pad_index,
            bos_index=self.word_bos_index,
            eos_index=self.word_eos_index,
        )
        model_path = os.path.join(self._task_path, "model_state.pdparams")
        # Load the model parameter for the predict
        state_dict = paddle.load(model_path)
        model_instance.set_dict(state_dict)
        model_instance.eval()
        self._model = model_instance

    def _construct_tokenizer(self, model):
        """
        Construct the tokenizer for the predictor.
        """
        return None

    def _preprocess_words(self, inputs):
        """
        Convert ``inputs["words"]`` (one list of words per sentence) into model batches.

        Adds ``inputs["data_loader"]``: one ``(words, position)`` pair from ``flat_words``
        per batch of ``self.batch_size`` sentences.
        """
        examples = [
            convert_example({"FORM": text}, vocabs=[self.word_vocab, self.rel_vocab]) for text in inputs["words"]
        ]
        batches = [examples[idx : idx + self.batch_size] for idx in range(0, len(examples), self.batch_size)]

        def batchify_fn(batch):
            # Each example is a one-element list; pad that field across the batch.
            raw_batch = [raw for raw in zip(*batch)]
            batch = [pad_sequence(data) for data in raw_batch]
            return batch

        inputs["data_loader"] = [flat_words(batchify_fn(batch)[0]) for batch in batches]
        return inputs

    def _preprocess(self, inputs):
        """
        Transform the raw text to the model inputs, two steps involved:
        1) Segment the text with LAC in batches, also POS-tagging it when ``use_pos`` is set.
        2) Convert the segmented words into padded id batches via ``_preprocess_words``.
        """
        outputs = {}
        lac_results = []
        position = 0
        inputs = self._check_input_text(inputs)
        while position < len(inputs):
            lac_results += self.lac.run(inputs[position : position + self.batch_size])
            position += self.batch_size
        if not self.use_pos:
            outputs["words"] = lac_results
        else:
            # In "lac" mode each result is a (words, tags) pair. Split them explicitly:
            # unpacking zip(*lac_results) fails when there are no results.
            outputs["words"] = [result[0] for result in lac_results]
            outputs["postags"] = [result[1] for result in lac_results]
        outputs = self._preprocess_words(outputs)
        return outputs

    def _run_model(self, inputs):
        """
        Run the inference predictor on every batch in ``inputs["data_loader"]``.

        Adds per-sentence ``arcs`` (head ids), ``rels`` (relation ids) and, when ``prob``
        is set, ``probs`` (head probabilities), each restricted to the output mask.
        """
        arcs, rels, probs = [], [], []
        for batch in inputs["data_loader"]:
            words, wp = batch
            self.input_handles[0].copy_from_cpu(words)
            self.input_handles[1].copy_from_cpu(wp)
            self.predictor.run()
            arc_preds = self.output_handle[0].copy_to_cpu()
            rel_preds = self.output_handle[1].copy_to_cpu()
            s_arc = self.output_handle[2].copy_to_cpu()
            mask = self.output_handle[3].copy_to_cpu().astype("bool")
            arc_preds, rel_preds = decode(arc_preds, rel_preds, s_arc, mask, self.tree)
            arcs.extend([arc_pred[m] for arc_pred, m in zip(arc_preds, mask)])
            rels.extend([rel_pred[m] for rel_pred, m in zip(rel_preds, mask)])
            if self.prob:
                arc_probs = probability(s_arc, arc_preds)
                probs.extend([arc_prob[m] for arc_prob, m in zip(arc_probs, mask)])
        inputs["arcs"] = arcs
        inputs["rels"] = rels
        inputs["probs"] = probs
        return inputs

    def _postprocess(self, inputs):
        """Build one result dict per sentence: word/head/deprel plus optional extras."""
        arcs = inputs["arcs"]
        rels = inputs["rels"]
        words = inputs["words"]
        arcs = [[s.item() for s in seq] for seq in arcs]
        rels = [self.rel_vocab.to_tokens(seq) for seq in rels]
        results = []
        for word, arc, rel in zip(words, arcs, rels):
            result = {
                "word": word,
                "head": arc,
                "deprel": rel,
            }
            results.append(result)
        if self.use_pos:
            postags = inputs["postags"]
            for result, postag in zip(results, postags):
                result["postag"] = postag
        if self.prob:
            probs = inputs["probs"]
            probs = [[round(p, 2) for p in seq.tolist()] for seq in probs]
            for result, prob in zip(results, probs):
                result["prob"] = prob
        if self.return_visual:
            for result in results:
                result["visual"] = self._visualize(result)
        return results

    @staticmethod
    def _canvas_to_bgr(canvas):
        """Return a drawn matplotlib canvas as a (height, width, 3) uint8 BGR array."""
        if hasattr(canvas, "buffer_rgba"):
            # The RGBA buffer's shape already matches the rendered pixels. Drop alpha and
            # reverse RGB -> BGR. Copy so the result does not alias the canvas buffer.
            rgba = np.asarray(canvas.buffer_rgba())
            return np.ascontiguousarray(rgba[:, :, 2::-1])
        # Fallback for canvases without buffer_rgba (tostring_rgb is deprecated in
        # newer matplotlib releases).
        rgb = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return rgb.reshape(canvas.get_width_height()[::-1] + (3,))[:, :, ::-1]

    def _visualize(self, data):
        """
        Visualize the dependency.

        Args:
            data(dict): A dict containing "word", "head" and "deprel".
        Returns:
            data: a uint8 BGR numpy array of shape (height, width, 3); use cv2.imshow to
            show it or cv2.imwrite to save it.
        """
        try:
            import matplotlib.font_manager as font_manager
            import matplotlib.pyplot as plt
        except Exception as err:
            raise ImportError("Please install the dependencies first, pip install matplotlib --upgrade") from err
        self.plt = plt
        self.font = font_manager.FontProperties(fname=self.font_file_path)
        word, head, deprel = data["word"], data["head"], data["deprel"]
        nodes = ["ROOT"] + word
        x = list(range(len(nodes)))
        y = [0] * len(nodes)
        fig, ax = self.plt.subplots()
        try:
            # Control the picture size
            max_span = max(abs(i + 1 - j) for i, j in enumerate(head))
            fig.set_size_inches((len(nodes), max_span / 2))
            # Set the points
            self.plt.scatter(x, y, c="w")
            for i in range(len(nodes)):
                txt = nodes[i]
                xytext = (i, 0)
                if i == 0:
                    # Set 'ROOT'
                    ax.annotate(
                        txt,
                        xy=xytext,
                        xycoords="data",
                        xytext=xytext,
                        textcoords="data",
                    )
                else:
                    xy = (head[i - 1], 0)
                    rad = 0.5 if head[i - 1] < i else -0.5
                    # Set the word
                    ax.annotate(
                        txt,
                        xy=xy,
                        xycoords="data",
                        xytext=(xytext[0] - 0.1, xytext[1]),
                        textcoords="data",
                        fontproperties=self.font,
                    )
                    # Draw the curve
                    ax.annotate(
                        "",
                        xy=xy,
                        xycoords="data",
                        xytext=xytext,
                        textcoords="data",
                        arrowprops=dict(
                            arrowstyle="<-",
                            shrinkA=12,
                            shrinkB=12,
                            color="blue",
                            connectionstyle="arc3,rad=%s" % rad,
                        ),
                    )
                    # Set the deprel label. Calculate its position by the radius
                    text_x = min(i, head[i - 1]) + abs((i - head[i - 1])) / 2 - 0.2
                    text_y = abs((i - head[i - 1])) / 4
                    ax.annotate(deprel[i - 1], xy=xy, xycoords="data", xytext=[text_x, text_y], textcoords="data")
            # Control the axis
            self.plt.axis("equal")
            self.plt.axis("off")
            # Save to numpy array
            fig.canvas.draw()
            data = self._canvas_to_bgr(fig.canvas)
        finally:
            # pyplot keeps every figure alive until it is closed. Without this, each
            # visualized sentence leaked one figure.
            self.plt.close(fig)
        return data
[docs]
def pad_sequence(sequences, padding_value=0, fix_len=None):
    """
    Stack arrays of varying leading length into one padded array.

    Args:
        sequences: non-empty sequence of np.ndarray sharing every dimension except the first.
        padding_value: value written into the padded positions.
        fix_len: if given, pad the leading dimension to exactly this length
            (must be >= the longest sequence).
    Returns:
        np.ndarray of shape (len(sequences), length, *trailing_dims) with the dtype of
        ``sequences[0]``.
    """
    first = sequences[0]
    longest = max(seq.shape[0] for seq in sequences)
    if fix_len is not None:
        assert fix_len >= longest, "fix_len is too small."
        longest = fix_len
    padded = np.full((len(sequences), longest, *first.shape[1:]), padding_value, dtype=first.dtype)
    # Each row of ``padded`` is a writable view, so copying into it fills the result.
    for row, seq in zip(padded, sequences):
        row[: seq.shape[0]] = seq
    return padded
def convert_example(example, vocabs, fix_len=20):
    """
    Turn one segmented sentence into a padded matrix of character ids.

    Args:
        example: dict whose "FORM" entry is the list of words of the sentence.
        vocabs: (word_vocab, rel_vocab); only the word vocab is used here.
        fix_len: characters kept per word; longer words are truncated.
    Returns:
        A one-element list holding an int64 array of shape (n_words + 2, fix_len):
        a "[CLS]" row, one row per word, then a "[SEP]" row.
    """
    word_vocab, _ = vocabs
    bos_row = [word_vocab.to_indices("[CLS]")]
    eos_row = [word_vocab.to_indices("[SEP]")]
    char_rows = [bos_row]
    char_rows.extend([word_vocab.to_indices(char) for char in word] for word in example["FORM"])
    char_rows.append(eos_row)
    arrays = [np.array(row[:fix_len], dtype=np.int64) for row in char_rows]
    return [pad_sequence(arrays, fix_len=fix_len)]
def flat_words(words, pad_index=0):
    """
    Flatten per-word character ids into one character sequence per sentence.

    Args:
        words: int array of shape (batch, n_words, n_chars), padded with ``pad_index``.
        pad_index: id used for padding.
    Returns:
        (words, position):
            words: (batch, max_chars) array of the non-pad character ids of each
                sentence, re-padded with ``pad_index``.
            position: (batch, n_words) index of each word's last character in that
                flattened sequence, clipped to ``max_chars - 1``.
    """
    mask = words != pad_index
    lens = np.sum(mask.astype(np.int64), axis=-1)
    # Empty (padding) word slots still advance the running count by one.
    position = np.cumsum(lens + (lens == 0).astype(np.int64), axis=1) - 1
    lens = np.sum(lens, -1)
    # Keep the non-pad characters in row-major order. This previously used
    # np.flatnonzero(words), which ignored pad_index and was only correct for 0.
    flat = words[mask]
    sequences = np.split(flat, np.cumsum(lens)[:-1])
    words = Pad(pad_val=pad_index)(sequences)
    max_len = words.shape[1]
    # Positions that ran past the flattened length (padding slots) point at the last column.
    position = np.minimum(position, max_len - 1)
    return words, position
def probability(s_arc, arc_preds):
    """
    Softmax the arc scores over the last axis and pick each token's predicted-head probability.

    Args:
        s_arc: float array of arc scores whose last axis is normalised.
        arc_preds: per-sentence arrays of predicted head indices.
    Returns:
        List with one array per sentence: ``softmax(s_arc)[b][k, arc_preds[b][k]]``.
    """
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = s_arc - s_arc.max(axis=-1, keepdims=True)
    # Exponentiate once (previously np.exp was evaluated twice).
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    return [p[np.arange(len(arc_pred)), arc_pred] for p, arc_pred in zip(probs, arc_preds)]
[docs]
def decode(arc_preds, rel_preds, s_arc, mask, tree):
    """
    Finalise head and relation predictions for a batch.

    Args:
        arc_preds: (batch, seq_len) predicted head index per position; modified in place
            when re-decoding is needed.
        rel_preds: per-sentence (seq_len, seq_len) relation ids, indexed [dependent, head].
        s_arc: arc scores, only used to re-decode non-tree sentences with ``eisner``.
        mask: (batch, seq_len) bool; ``mask.sum(-1)`` is each sentence's word count.
        tree: if True, sentences whose heads do not form a projective tree are re-decoded.
    Returns:
        (arc_preds, rel_preds), where rel_preds now holds, for each sentence, the relation
        of every position to its chosen head.
    """
    if tree:
        # Only pay for the tree check when its result can be used.
        lens = np.sum(mask, -1)
        bad = np.array([not istree(seq[: i + 1]) for i, seq in zip(lens, arc_preds)], dtype=bool)
        if bad.any():
            arc_preds[bad] = eisner(s_arc[bad], mask[bad])
    rel_preds = [rel_pred[np.arange(len(arc_pred)), arc_pred] for arc_pred, rel_pred in zip(arc_preds, rel_preds)]
    return arc_preds, rel_preds
[docs]
def eisner(scores, mask):
    """
    Eisner algorithm is a general dynamic programming decoding algorithm for bilexical grammar.
    Here it decodes, for every sentence in the batch, the highest-scoring projective
    dependency tree from the arc scores.

    Args:
        scores: Adjacency matrix of arc scores, shape=(batch, seq_len, seq_len). Judging from
            the recurrence comments below, ``scores[b, d, h]`` is read as the score of
            head ``h`` -> dependent ``d`` (TODO confirm against the model output).
        mask: Mask matrix, shape=(batch, seq_len); ``mask.sum(1)`` is used as each
            sentence's length.
    Returns:
        output, shape=(batch, seq_len): the index of the parent node of each token,
        padded to ``seq_len`` by ``pad_sequence``.
    """
    # Sentence lengths (number of unmasked positions per sentence).
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # Move the batch axis last so span diagonals are taken over the first two axes.
    scores = scores.transpose(2, 1, 0)
    # Score for incomplete span
    s_i = np.full_like(scores, float("-inf"))
    # Score for complete span
    s_c = np.full_like(scores, float("-inf"))
    # Incomplete span position for backtrack
    p_i = np.zeros((seq_len, seq_len, batch_size), dtype=np.int64)
    # Complete span position for backtrack
    p_c = np.zeros((seq_len, seq_len, batch_size), dtype=np.int64)
    # Set 0 to s_c.diagonal (a single-token complete span costs nothing)
    s_c = fill_diagonal(s_c, 0)
    # Contiguous copies: stripe() builds strided views that assume C-contiguous layout.
    s_c = np.ascontiguousarray(s_c)
    s_i = np.ascontiguousarray(s_i)
    # Grow spans by width w; each iteration fills the w-th off-diagonals.
    for w in range(1, seq_len):
        n = seq_len - w
        starts = np.arange(n, dtype=np.int64)[np.newaxis, :]
        # ilr = C(i->r) + C(j->r+1)
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        # Shape: [batch_size, n, w]
        ilr = ilr.transpose(2, 0, 1)
        # scores.diagonal(-w).shape:[batch, n]
        il = ilr + scores.diagonal(-w)[..., np.newaxis]
        # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
        il_span, il_path = il.max(-1), il.argmax(-1)
        s_i = fill_diagonal(s_i, il_span, offset=-w)
        p_i = fill_diagonal(p_i, il_path + starts, offset=-w)
        ir = ilr + scores.diagonal(w)[..., np.newaxis]
        # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
        ir_span, ir_path = ir.max(-1), ir.argmax(-1)
        s_i = fill_diagonal(s_i, ir_span, offset=w)
        p_i = fill_diagonal(p_i, ir_path + starts, offset=w)
        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl = cl.transpose(2, 0, 1)
        cl_span, cl_path = cl.max(-1), cl.argmax(-1)
        s_c = fill_diagonal(s_c, cl_span, offset=-w)
        p_c = fill_diagonal(p_c, cl_path + starts, offset=-w)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr = cr.transpose(2, 0, 1)
        cr_span, cr_path = cr.max(-1), cr.argmax(-1)
        s_c = fill_diagonal(s_c, cr_span, offset=w)
        # A complete span from the root (position 0) may only end at the sentence's real length.
        s_c[0, w][np.not_equal(lens, w)] = float("-inf")
        p_c = fill_diagonal(p_c, cr_path + starts + 1, offset=w)
    predicts = []
    # Back-pointers indexed [batch, i, j] for per-sentence backtracking.
    p_c = p_c.transpose(2, 0, 1)
    p_i = p_i.transpose(2, 0, 1)
    for i, length in enumerate(lens.tolist()):
        heads = np.ones(length + 1, dtype=np.int64)
        # Recover the heads starting from the complete span rooted at position 0.
        backtrack(p_i[i], p_c[i], heads, 0, length, True)
        predicts.append(heads)
    return pad_sequence(predicts, fix_len=seq_len)
[docs]
def fill_diagonal(x, value, offset=0, dim1=0, dim2=1):
    """
    Fill ``value`` into the diagonal of the 3-D array ``x`` with offset ``offset``,
    taken over the axes (dim1, dim2).

    The write goes through a strided view of ``x``, so ``x`` is modified in place;
    it is also returned for convenience.

    Args:
        x: 3-D array with ``x.shape[dim1] == x.shape[dim2]``.
        value: scalar or array assignable to the diagonal view, whose shape is
            ``(x.shape[dim3], x.shape[dim1] - abs(offset))`` where dim3 is the remaining axis.
        offset: >= 0 shifts the diagonal along dim2 (above the main diagonal);
            < 0 shifts it along dim1 (below).
        dim1, dim2: the two axes spanning the diagonal (swapped if given in reverse order).
    Returns:
        x, after the in-place fill.
    """
    strides = x.strides
    shape = x.shape
    if dim1 > dim2:
        dim1, dim2 = dim2, dim1
    # Argument checks use assert, so they are skipped under ``python -O``.
    assert 0 <= dim1 < dim2 <= 2
    assert len(x.shape) == 3
    assert shape[dim1] == shape[dim2]
    # dim1 + dim2 identifies the axis pair (1 -> (0, 1), 2 -> (0, 2), 3 -> (1, 2));
    # dim3 is the remaining axis.
    dim_sum = dim1 + dim2
    dim3 = 3 - dim_sum
    if offset >= 0:
        # Start the view ``offset`` steps into dim2, then step along dim1 and dim2 together.
        diagonal = np.lib.stride_tricks.as_strided(
            x[:, offset:] if dim_sum == 1 else x[:, :, offset:],
            shape=(shape[dim3], shape[dim1] - offset),
            strides=(strides[dim3], strides[dim1] + strides[dim2]),
        )
    else:
        # Start the view ``-offset`` steps into dim1.
        diagonal = np.lib.stride_tricks.as_strided(
            x[-offset:, :] if dim_sum in [1, 2] else x[:, -offset:],
            shape=(shape[dim3], shape[dim1] + offset),
            strides=(strides[dim3], strides[dim1] + strides[dim2]),
        )
    diagonal[...] = value
    return x
[docs]
def backtrack(p_i, p_c, heads, i, j, complete):
    """
    Walk the Eisner back-pointers for span (i, j) and write every dependent's head into ``heads``.

    Args:
        p_i: split points of incomplete spans, indexed [head, dependent].
        p_c: split points of complete spans, indexed [start, end].
        heads: array updated in place: ``heads[d] = h`` for each arc recovered.
        i, j: span endpoints; i == j is an empty span.
        complete: whether (i, j) is a complete span (True) or an incomplete one (False).
    """
    # Explicit stack instead of recursion; children are pushed right-to-left so spans
    # are expanded in the same depth-first order as the recursive formulation.
    pending = [(i, j, complete)]
    while pending:
        lo, hi, is_complete = pending.pop()
        if lo == hi:
            continue
        if is_complete:
            split = p_c[lo, hi]
            pending.append((split, hi, True))
            pending.append((lo, split, False))
        else:
            split = p_i[lo, hi]
            heads[hi] = lo
            left, right = sorted((lo, hi))
            pending.append((right, split + 1, True))
            pending.append((left, split, True))
[docs]
def stripe(x, n, w, offset=(0, 0), dim=1):
    """
    Returns a diagonal stripe of the array.

    Row k of the stripe starts at ``x[offset[0] + k, offset[1] + k]`` and runs ``w``
    elements along the columns (dim=1) or down the rows (dim=0). When ``x`` is already
    C-contiguous the result is a view of it (no bounds checking is done).

    Args:
        x (np.ndarray): the input array with 2 or more dims.
        n (int): the length of the stripe.
        w (int): the width of the stripe.
        offset (tuple): the offset of the first two dims.
        dim (int): 0 if returns a horizontal stripe; 1 else.
    Example:
        >>> x = np.arange(25).reshape(5, 5)
        >>> x
        array([[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14],
               [15, 16, 17, 18, 19],
               [20, 21, 22, 23, 24]])
        >>> stripe(x, 2, 3, (1, 1))
        array([[ 6,  7,  8],
               [12, 13, 14]])
        >>> stripe(x, 2, 3, dim=0)
        array([[ 0,  5, 10],
               [ 6, 11, 16]])
    """
    if not x.flags["C_CONTIGUOUS"]:
        x = np.ascontiguousarray(x)
    byte_strides = x.strides
    # Moving to the next stripe row steps one position along the main diagonal.
    diagonal_step = byte_strides[0] + byte_strides[1]
    inner_step = byte_strides[1] if dim == 1 else byte_strides[0]
    origin = x[offset[0] :, offset[1] :]
    return np.lib.stride_tricks.as_strided(
        origin,
        shape=(n, w, *x.shape[2:]),
        strides=(diagonal_step, inner_step, *byte_strides[2:]),
    )
[docs]
class Node:
    """A position in a dependency tree, with its left and right dependents."""

    def __init__(self, id=None, parent=None):
        # Ids of dependents positioned before / after this node (DepTree.add keeps them sorted).
        self.lefts = []
        self.rights = []
        # Both ids may stay None. Previously ``int(None)`` made the declared default
        # ``id=None`` raise TypeError.
        self.id = id if id is None else int(id)
        self.parent = parent if parent is None else int(parent)
[docs]
class DepTree:
    """
    DepTree class, used to check whether the prediction result is a projective tree.
    A projective tree is one whose arcs can be drawn above the sentence without crossing.

    Args:
        sentence: head index for each position; position 0 is the root and its entry is
            replaced by -1 on a private copy.
    """

    def __init__(self, sentence):
        # Work on a copy so setting the root head to -1 leaves the caller's data untouched.
        sentence = copy.deepcopy(sentence)
        sentence[0] = -1
        self.sentence = sentence
        self.build_tree()
        self.visit = [False] * len(sentence)

    def build_tree(self):
        """Build one Node per position and attach every non-root node to its head."""
        self.nodes = [Node(index, p_index) for index, p_index in enumerate(self.sentence)]
        # set root
        self.root = self.nodes[0]
        for node in self.nodes[1:]:
            self.add(self.nodes[node.parent], node)

    def add(self, parent, child):
        """Register ``child`` as a left or right dependent of ``parent``, keeping the lists sorted."""
        if parent.id is None or child.id is None:
            raise Exception("id is None")
        if parent.id < child.id:
            parent.rights = sorted(parent.rights + [child.id])
        else:
            parent.lefts = sorted(parent.lefts + [child.id])

    def judge_legal(self):
        """
        Determine whether it is a projective tree.

        True iff the root has exactly one dependent and the in-order traversal from the
        root visits every position exactly once, in sentence order.
        """
        target_seq = list(range(len(self.nodes)))
        if len(self.root.lefts + self.root.rights) != 1:
            return False
        # Reset the visit marks for every check. Otherwise the marks left by a previous
        # call make the traversal return [] and any second call reports False.
        self.visit = [False] * len(self.nodes)
        cur_seq = self.inorder_traversal(self.root)
        return target_seq == cur_seq

    def inorder_traversal(self, node):
        """Inorder traversal (left dependents, node, right dependents); nodes already visited are skipped."""
        if self.visit[node.id]:
            return []
        self.visit[node.id] = True
        lf_list = []
        rf_list = []
        for ln in node.lefts:
            lf_list += self.inorder_traversal(self.nodes[ln])
        for rn in node.rights:
            rf_list += self.inorder_traversal(self.nodes[rn])
        return lf_list + [node.id] + rf_list
[docs]
def istree(sequence):
    """
    Return whether ``sequence`` forms a projective tree.

    ``sequence`` holds one head index per position, with position 0 being the root
    (its own entry is ignored). The check also requires exactly one word attached to the root.
    """
    dep_tree = DepTree(sequence)
    is_legal = dep_tree.judge_legal()
    return is_legal