From 91a4426d4be1718762b21b920dbc0a0ab2b81a3e Mon Sep 17 00:00:00 2001 From: telescopic Date: Sat, 10 Oct 2020 03:09:39 +0530 Subject: [PATCH] Fix: handle word not in vocabulary error --- incorrect_answer_generation.py | 35 +++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/incorrect_answer_generation.py b/incorrect_answer_generation.py index bdb75a9..b1e0f4f 100644 --- a/incorrect_answer_generation.py +++ b/incorrect_answer_generation.py @@ -2,10 +2,12 @@ for generating incorrect alternative answers for a given answer ''' +import gensim import gensim.downloader as api +from gensim.models import Word2Vec from nltk.tokenize import sent_tokenize, word_tokenize import random - +import numpy as np class IncorrectAnswerGenerator: ''' This class contains the methods @@ -13,20 +15,43 @@ class IncorrectAnswerGenerator: given an answer ''' - def __init__(self): + def __init__(self, document): # model required to fetch similar words self.model = api.load("glove-wiki-gigaword-100") + self.all_words = [] + for sent in sent_tokenize(document): + self.all_words.extend(word_tokenize(sent)) + self.all_words = list(set(self.all_words)) def get_all_options_dict(self, answer, num_options): ''' This method returns a dict of 'num_options' options out of which one is correct and is the answer ''' - similar_words = self.model.similar_by_word(answer, topn=15)[::-1] options_dict = dict() + try: + similar_words = self.model.similar_by_word(answer, topn=15)[::-1] - for i in range(1, num_options + 1): - options_dict[i] = similar_words[i - 1][0] + for i in range(1, num_options + 1): + options_dict[i] = similar_words[i - 1][0] + + except: + self.all_sim = [] + for word in self.all_words: + if word not in answer: + try: + self.all_sim.append( + (self.model.similarity(answer, word), word)) + except: + self.all_sim.append( + (0.0, word)) + else: + self.all_sim.append((-1.0, word)) + + self.all_sim.sort(reverse=True) + + for i in range(1, num_options+1): + options_dict[i] = self.all_sim[i-1][1] replacement_idx = random.randint(1, num_options)