Skip to content
Snippets Groups Projects
Commit c8826774 authored by Adam Pawlaczek
Browse files

Fixed double heads

parent 94f441d3
Branches
No related merge requests found
......@@ -109,13 +109,13 @@ class Chunker:
self.conf_dir, self.data_dir, self.layers)
self.layer_models = None # layer_name -> trained classifier
self.stats = Stats()
def load_model(self):
    """Load the stored classifier for each layer listed in self.layers.

    Resets self.layer_models to an empty dict, then stores one entry per
    layer, keyed by the layer and holding whatever classify.load returns
    for that layer.
    """
    self.layer_models = {}
    for layer_name in self.layers:
        # one classify.load call per layer, using the shared conf/model settings
        trained = classify.load(
            self.conf, self.model_name, self.data_dir, layer_name)
        self.layer_models[layer_name] = trained
def train_and_save(self, in_path, input_format):
"""Trains the chunker and stores the model to files beginning with
model_name."""
......@@ -125,19 +125,19 @@ class Chunker:
# open files for storing training examples for each layer
tr_files = classify.open_tr_files(
self.model_name, self.data_dir, self.layers)
# set-up the reader and gather feature values for subsequent sentences
reader = corpio.get_reader(
in_path, self.tagset, input_format, self.is_input_tagged)
self.stats.clear()
while True:
sent = reader.get_next_sentence()
if not sent:
break # end of input
# wrap the sentence as an AnnotatedSentence
asent = corpus2.AnnotatedSentence.wrap_sentence(sent)
# iterate over layers
for layer_idx, layer in enumerate(self.layers):
chans = self.layer_channels[layer_idx]
......@@ -194,7 +194,7 @@ class Chunker:
if self.verbose:
sys.stderr.write('done!\n')
self.stats.dump()
def tag_sentence(self, sent):
"""Chunks the given sentence."""
# wrap the sentence as an AnnotatedSentence
......@@ -223,6 +223,8 @@ class Chunker:
for op in self.layer_ops[layer_idx]]
classify.eat_token(model, feat_vals)
classify.close_sent(model)
last_iobs = {}
for tok_idx, tok in enumerate(sent.tokens()):
decsn = classify.classify_token(model, tok_idx)
non_O_chan = None
......@@ -238,20 +240,23 @@ class Chunker:
raise IOError('Unexpected label returned from classifier: ' + decsn)
for chan_name in chans:
chan = asent.get_channel(chan_name)
# TODO: rename the from_string in corpus2 and fix it here
tag_to_set = 'O' if chan_name != non_O_chan else non_O_tag
chan.set_iob_at(tok_idx, corpus2.from_string(tag_to_set))
if tag_to_set != 'O' and is_head:
if tag_to_set == "I" and (not last_iobs.has_key(chan_name) or last_iobs[chan_name] == "O"):
tag_to_set = 'B'
if tag_to_set == 'B':
head_idx = None
seg_no = chan.get_new_segment_index()
if tag_to_set in "BI":
chan.set_segment_at(tok_idx, seg_no)
last_iobs[chan_name] = tag_to_set
if tag_to_set != 'O' and is_head and head_idx == None:
chan.set_head_at(tok_idx, True)
# switch back to segments
for chan_name in chans:
chan = asent.get_channel(chan_name)
chan.make_segments_from_iob()
head_idx = tok_idx
self.stats.num_sents += 1
self.stats.num_toks += sent.tokens().size()
if self.verbose: self.stats.maybe_report()
def tag_input(self, in_path, out_path, input_format, output_format,
preserve_pars):
"""Chunks the input and writes processed input to out_path or stdout if
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment