def train(self, corpus):
    """Accumulate unigram counts, casing variants and n-gram statistics.

    Args:
        corpus: Iterable of sentences, each an iterable of word strings
            (every word must support ``str.lower``).

    Side effects:
        Updates ``self.uni_dist`` and ``self.word_casing_lookup`` in place
        and calls the private bigram/trigram update helpers once per word.

    NOTE(review): ``self.uni_dist[word] += 1`` is applied to words that may
    not have been seen before, so ``uni_dist`` is presumably a
    ``Counter``/``defaultdict(int)`` -- confirm where it is initialised.
    """
    for sentence in corpus:
        # Skip sentences rejected by the sanity check. The original comment
        # says this drops sentences that are entirely upper-case; the check
        # itself is defined outside this block.
        if not self.check_sentence_sanity(sentence):
            continue

        for word_idx, word in enumerate(sentence):
            self.uni_dist[word] += 1

            # Record every surface casing observed for this word, keyed by
            # its lower-cased form.
            word_lower = word.lower()
            self.word_casing_lookup.setdefault(word_lower, set()).add(word)

            # Add the word's context to the bigram language-model counts.
            self.__function_one(sentence, word, word_idx, word_lower)
            # Add the word's context to the trigram language-model counts.
            self.__function_two(sentence, word, word_idx)
训练好的模型用下面代码存放到pickle中:
def save_to_file(self, file_path):
    """Pickle the model's distributions and casing lookup to ``file_path``.

    The file contains a single dict with the keys ``uni_dist``,
    ``backward_bi_dist``, ``forward_bi_dist``, ``trigram_dist`` and
    ``word_casing_lookup``, each holding the attribute of the same name.

    Args:
        file_path: Destination path (``str`` or any path-like object that
            ``open`` accepts). An existing file is overwritten.

    Raises:
        OSError: If the file cannot be opened for writing.

    NOTE: pickle files must only be loaded back from trusted sources;
    unpickling untrusted data can execute arbitrary code.
    """
    pickle_dict = {
        "uni_dist": self.uni_dist,
        "backward_bi_dist": self.backward_bi_dist,
        "forward_bi_dist": self.forward_bi_dist,
        "trigram_dist": self.trigram_dist,
        "word_casing_lookup": self.word_casing_lookup,
    }
    with open(file_path, "wb") as fp:
        pickle.dump(pickle_dict, fp)
    # str() keeps the message working when file_path is a pathlib.Path.
    print("Model saved to " + str(file_path))
模型推理的主函数:
def get_true_case(self, sentence, out_of_vocabulary_token_option="title"):
    """Return ``sentence`` with a predicted casing applied to every token.

    The sentence is split with ``self.tknzr.tokenize``. Each token is then
    handled in one of three ways:

    * punctuation and digit-only tokens are copied through unchanged;
    * tokens whose lower-cased form is a key of ``self.word_casing_lookup``
      take their only known casing, or, when several casings are known,
      the one for which ``self.get_score(prev, candidate, next)`` is
      highest;
    * all other (out-of-vocabulary) tokens follow
      ``out_of_vocabulary_token_option``.

    Args:
        sentence: Text to re-case.
        out_of_vocabulary_token_option: ``"title"`` emits OOV tokens
            title-cased, ``"lower"`` emits them lower-cased, and any other
            value (conventionally ``"as-is"``) emits them exactly as the
            tokenizer produced them.

    Returns:
        The re-cased tokens joined into one string. Tokens are separated by
        a single space, except that punctuation tokens and tokens starting
        with an apostrophe are attached directly to the preceding token.
    """
    tokens = self.tknzr.tokenize(sentence)
    tokens_true_case = []
    for token_idx, token in enumerate(tokens):
        # Punctuation and numbers are emitted unchanged.
        if token in string.punctuation or token.isdigit():
            tokens_true_case.append(token)
            continue

        # Keep the original token: the "as-is" OOV option needs it.
        token_lower = token.lower()

        if token_lower in self.word_casing_lookup:
            candidates = self.word_casing_lookup[token_lower]
            if len(candidates) == 1:
                # Only one known casing: use it directly.
                tokens_true_case.append(list(candidates)[0])
            else:
                # Previous context is the already re-cased token; the next
                # token has not been processed yet, so it is taken raw.
                prev_token = (tokens_true_case[token_idx - 1]
                              if token_idx > 0 else None)
                next_token = (tokens[token_idx + 1]
                              if token_idx < len(tokens) - 1 else None)

                # Pick the candidate casing with the highest model score.
                best_token = None
                highest_score = float("-inf")
                for possible_token in candidates:
                    score = self.get_score(prev_token, possible_token,
                                           next_token)
                    if score > highest_score:
                        best_token = possible_token
                        highest_score = score
                tokens_true_case.append(best_token)

            # A known word that opens the sentence is title-cased.
            if token_idx == 0:
                tokens_true_case[0] = tokens_true_case[0].title()
        else:
            # Out-of-vocabulary token.
            if out_of_vocabulary_token_option == "title":
                tokens_true_case.append(token_lower.title())
            elif out_of_vocabulary_token_option == "lower":
                tokens_true_case.append(token_lower)
            else:
                tokens_true_case.append(token)

    # Re-join: punctuation and apostrophe-led tokens (e.g. "'s") attach to
    # the previous token; everything else is preceded by a space.
    pieces = []
    for cased in tokens_true_case:
        if cased.startswith("'") or cased in string.punctuation:
            pieces.append(cased)
        else:
            pieces.append(" " + cased)
    return "".join(pieces).strip()