-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
58 lines (48 loc) · 1.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import csv
import sys
import tqdm
import typing as T
from config import *
from pipeline.metric import set_corpus
def main() -> int:
    """Evaluate every source on the queries and write the metric table as CSV.

    Queries are taken from the command-line arguments when any are given;
    otherwise the ``QUERIES`` default is used. The table is written to
    ``results.csv`` inside ``RESULT_DIR``, which is created if it is missing.

    Returns:
        Always ``0``; the script entry point passes it to ``sys.exit``.
    """
    # An empty argument list is falsy, so this falls back to the defaults.
    queries = sys.argv[1:] or QUERIES
    header = table_header(SOURCES, SOURCE_TYPE, METRICS)
    os.makedirs(RESULT_DIR, exist_ok=True)
    result_path = os.path.join(RESULT_DIR, 'results.csv')
    # newline='' is what the csv module requires: it emits its own line
    # terminators, and newline translation would otherwise add blank rows
    # on some platforms. UTF-8 keeps non-ASCII queries writable regardless
    # of the locale's default encoding.
    with open(result_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, header)
        writer.writeheader()
        # The header is written before the rows are computed, so the file
        # already exists (with its header) while the sources are running.
        rows = table_rows(queries, SOURCES, SOURCE_TYPE, METRICS, WRITERS)
        writer.writerows(rows)
    return 0
def table_header(sources: SourceDictType,
                 source_type: SourceTypeType,
                 metrics: Metrics) -> T.List[str]:
    """Build the column names of the result table.

    The first column is ``'query'``. It is followed by one
    ``'<source_name>_<metric>'`` column for every source and every metric
    registered for that source's output type, in iteration order.

    Args:
        sources: Mapping keyed by source name; only the keys are used here.
        source_type: Maps each source name to its output type.
        metrics: Maps each output type to a collection whose iteration
            yields the metric names.

    Returns:
        A flat list of column names.

    Raises:
        KeyError: If a source has no entry in ``source_type`` or its output
            type has no entry in ``metrics``.
    """
    header = ['query']
    for source_name in sources:
        metric_names = metrics[source_type[source_name]]
        header.extend(f'{source_name}_{metric}' for metric in metric_names)
    return header
def table_rows(queries: T.List[str],
               sources: SourceDictType,
               source_type: SourceTypeType,
               metrics: Metrics,
               writers: WriterTypes) -> T.List[T.Dict[str, T.Any]]:
    """Run every source on every query and collect the metric values.

    For each source, the responses to all queries are computed, registered
    through ``set_corpus``, written below ``CORPUS_DIR/<source_name>/`` and
    scored with the metrics registered for the source's output type.

    Args:
        queries: Queries to send to each source.
        sources: Maps a source name to a callable taking a query and
            returning a response.
        source_type: Maps each source name to its output type.
        metrics: Maps each output type to a ``{metric_name: callable}``
            mapping; each callable receives a single response.
        writers: Maps each output type to a callable
            ``writer(response, path)`` that stores the response.

    Returns:
        One dict per distinct query, in first-seen order, holding the
        ``'query'`` value plus one ``'<source_name>_<metric_name>'`` entry
        per computed metric. Rows are keyed by query, so a repeated query
        shares a single row.

    Raises:
        KeyError: If a source has no entry in ``source_type``, or its output
            type has no entry in ``writers`` or ``metrics``.
    """
    # Keyed by query so that every source adds its columns to the same row.
    rows = {query: {'query': query} for query in queries}
    for source_name, source in sources.items():
        corpus_dir = os.path.join(CORPUS_DIR, source_name)
        os.makedirs(corpus_dir, exist_ok=True)
        source_output_type = source_type[source_name]
        writer = writers[source_output_type]
        # Looked up once per source, so a missing entry fails before the
        # corpus is built rather than after the first response is written.
        source_metrics = metrics[source_output_type]
        print(f"Building corpus from source {source_name}")
        corpus = [source(query) for query in tqdm.tqdm(queries)]
        # Registered before any metric is evaluated.
        # NOTE(review): presumably the metrics read this corpus-wide state;
        # confirm against pipeline.metric.
        set_corpus(corpus)
        for query, response in zip(queries, corpus):
            # NOTE(review): the query text is used verbatim as the file
            # name; a query containing a path separator would land outside
            # corpus_dir -- confirm queries are safe file names.
            writer(response, os.path.join(corpus_dir, f'{query}'))
            for metric_name, metric in source_metrics.items():
                rows[query][f'{source_name}_{metric_name}'] = metric(response)
    return list(rows.values())
# Script entry point: run main() and use its return value as the exit status.
if __name__ == '__main__':
    raise SystemExit(main())