From 4f63d3ccd4d7058df516f49f51533fa80cad52dd Mon Sep 17 00:00:00 2001 From: Gijs Hendriksen Date: Mon, 16 Dec 2019 16:01:44 +0100 Subject: [PATCH] Add assertion that scores are equal across engines --- index.py | 5 +++-- main.py | 18 +++++++++++++++--- query.py | 2 +- search.py | 1 + 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/index.py b/index.py index 3d279ad..0b5736f 100644 --- a/index.py +++ b/index.py @@ -25,7 +25,7 @@ class Index(ABC): for term in word_tokenize(body.lower()): if term not in self.stopwords and term.isalpha(): - terms[term[:32]] += 1 + terms[term] += 1 return terms @@ -208,7 +208,8 @@ class DuckDBIndex(Index): def search(self, query): self.cursor.execute(query) - return self.cursor.fetchdf() + df = self.cursor.fetchdf() + return list(df.itertuples(index=False, name=None))[:10] def clear(self): self.cursor.execute("DELETE FROM terms") diff --git a/main.py b/main.py index 5d33dfe..b6208c6 100755 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ import argparse +import math import time from index import Index, DuckDBIndex, MonetDBIndex from search import Search @@ -40,28 +41,34 @@ def benchmark(args: argparse.Namespace): iterations = 20 + scores = [[] for _ in range(len(indices))] + for filename in args.input: benchmark_times = [] print(f'Filename: "{filename}"') - for index in indices: + for i, index in enumerate(indices): index.clear() print('Indexing...') index.bulk_index(filename) + search = Search(index) + times = [] for query in queries: start = time.time() for _ in range(iterations): - search = Search(index) search.search(query) end = time.time() avg_time = (end - start) / iterations - times.append(f'{avg_time:.04}s') + times.append(f'{avg_time:.4f}s') + + # Compare the scores to verify both engines return the same results + scores[i].append(search.search(query)) benchmark_times.append(times) @@ -74,6 +81,11 @@ def benchmark(args: argparse.Namespace): print() + for i in range(len(scores[0])): + for duck_scores, monet_scores in zip(scores[0][i], scores[1][i]): + assert duck_scores[0] == monet_scores[0], 'Retrieved documents are not equal!' + assert math.isclose(duck_scores[1], monet_scores[1], abs_tol=1e-2), f'Scores are unequal: {duck_scores[1]}, {monet_scores[1]}' + def dump_index(args: argparse.Namespace): index = Index.get_index(args.engine, args.database) diff --git a/query.py b/query.py index 4d1e471..a7a139f 100644 --- a/query.py +++ b/query.py @@ -16,7 +16,7 @@ def bm25(terms, disjunctive=True): AS cdocs ON term_tf.docid = cdocs.docid JOIN docs ON term_tf.docid=docs.docid JOIN dict ON term_tf.termid=dict.termid) - SELECT scores.docid, score FROM (SELECT docid, sum(subscore) AS score + SELECT docs.name, score FROM (SELECT docid, sum(subscore) AS score FROM subscores GROUP BY docid) AS scores JOIN docs ON scores.docid=docs.docid ORDER BY score DESC; """ diff --git a/search.py b/search.py index 1ea9125..e31096a 100644 --- a/search.py +++ b/search.py @@ -9,6 +9,7 @@ class Search: self.index = index def search(self, terms, method='bm25'): + terms = self.index.get_terms(' '.join(terms)).keys() if method == 'bm25': sql_query = query.bm25(terms) else: -- GitLab