Skip to content

Commit 7e2f977

Browse files
authored
Add Algolia Support (#105)
* Add Algolia Support Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com> * Add Changelog Fragment Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com> --------- Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com>
1 parent b7b4918 commit 7e2f977

File tree

27 files changed

+493
-492
lines changed

27 files changed

+493
-492
lines changed

benchmark/benchmark.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from processors import Draw, Rank, Win
21
from prompt_toolkit import HTML
32
from prompt_toolkit import print_formatted_text as print
43
from prompt_toolkit import prompt
@@ -12,6 +11,7 @@
1211
ThurstoneMostellerFull,
1312
ThurstoneMostellerPart,
1413
)
14+
from processors import Draw, Rank, Win
1515

1616

1717
class NumberValidator(Validator):
@@ -48,6 +48,13 @@ def validate(self, document):
4848
input_benchmark_type = prompt(
4949
"Benchmark Processor: ", completer=benchmark_types_completer
5050
)
51+
minimum_matches = int(
52+
prompt("Minimum Matches Per Player: ", validator=NumberValidator())
53+
)
54+
if minimum_matches < 1:
55+
print(HTML("<style fg='Red'>Invalid Match Count</style>"))
56+
quit()
57+
5158
input_seed = int(prompt("Enter Random Seed: ", validator=NumberValidator()))
5259

5360
model = None
@@ -59,15 +66,30 @@ def validate(self, document):
5966

6067
if input_benchmark_type in benchmark_type_names.keys() and model:
6168
if input_benchmark_type == "Win":
62-
win_processor = Win(path="data/overwatch.jsonl", seed=input_seed, model=model)
69+
win_processor = Win(
70+
path="data/overwatch.jsonl",
71+
seed=input_seed,
72+
minimum_matches=minimum_matches,
73+
model=model,
74+
)
6375
win_processor.process()
6476
win_processor.print_result()
6577
elif input_benchmark_type == "Draw":
66-
draw_processor = Draw(path="data/chess.csv", seed=input_seed, model=model)
78+
draw_processor = Draw(
79+
path="data/chess.csv",
80+
seed=input_seed,
81+
minimum_matches=minimum_matches,
82+
model=model,
83+
)
6784
draw_processor.process()
6885
draw_processor.print_result()
6986
elif input_benchmark_type == "Rank":
70-
rank_processor = Rank(path="data/overwatch.jsonl", seed=input_seed, model=model)
87+
rank_processor = Rank(
88+
path="data/overwatch.jsonl",
89+
seed=input_seed,
90+
minimum_matches=minimum_matches,
91+
model=model,
92+
)
7193
rank_processor.process()
7294
rank_processor.print_result()
7395
else:

benchmark/processors/draw.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from prompt_toolkit import HTML
1010
from prompt_toolkit import print_formatted_text as print
1111
from prompt_toolkit.shortcuts import ProgressBar
12+
from prompt_toolkit.styles import Style
1213
from sklearn.model_selection import train_test_split
1314

14-
import openskill
1515
from openskill.models import (
1616
BradleyTerryFull,
1717
BradleyTerryPart,
@@ -46,6 +46,7 @@ def __init__(
4646
self,
4747
path,
4848
seed: int,
49+
minimum_matches: int,
4950
model: Union[
5051
BradleyTerryFull,
5152
BradleyTerryPart,
@@ -60,11 +61,13 @@ def __init__(
6061
for match_index, row in df.iterrows():
6162
self.data.append(row)
6263
self.seed = seed
64+
self.minimum_matches = minimum_matches
6365
self.model = model
6466

6567
# Counters
6668
self.match_count = {}
67-
self.confident_matches = 0
69+
self.available_matches = 0
70+
self.valid_matches = 0
6871
self.openskill_correct_predictions = 0
6972
self.openskill_incorrect_predictions = 0
7073
self.trueskill_correct_predictions = 0
@@ -87,15 +90,24 @@ def __init__(
8790
self.trueskill_time = None
8891

8992
def process(self):
93+
style = Style.from_dict(
94+
{
95+
"label": "bg:#ffff00 #000000",
96+
"percentage": "bg:#ffff00 #000000",
97+
"current": "#448844",
98+
"bar": "",
99+
}
100+
)
101+
90102
title = HTML(f'<style fg="Red">Counting Matches</style>')
91-
with ProgressBar(title=title) as progress_bar:
103+
with ProgressBar(title=title, style=style) as progress_bar:
92104
for match in progress_bar(self.data, total=len(self.data)):
93105
if self.consistent(match=match):
94106
self.count(match=match)
95107

96108
# Check if data has sufficient history.
97109
title = HTML(f'<style fg="Red">Verifying History</style>')
98-
with ProgressBar(title=title) as progress_bar:
110+
with ProgressBar(title=title, style=style) as progress_bar:
99111
for match in progress_bar(self.data, total=len(self.data)):
100112
if self.consistent(match=match):
101113
if self.has_sufficient_history(match=match):
@@ -111,7 +123,7 @@ def process(self):
111123
title = HTML(
112124
f'Updating OpenSkill Ratings with <style fg="Green">{self.model.__name__}</style> Model:'
113125
)
114-
with ProgressBar(title=title) as progress_bar:
126+
with ProgressBar(title=title, style=style) as progress_bar:
115127
os_process_time_start = time.time()
116128
for match in progress_bar(self.training_set, total=len(self.training_set)):
117129
self.process_openskill(match=match)
@@ -122,7 +134,7 @@ def process(self):
122134
title = HTML(
123135
f'Updating Ratings with <style fg="Green">TrueSkill</style> Model:'
124136
)
125-
with ProgressBar(title=title) as progress_bar:
137+
with ProgressBar(title=title, style=style) as progress_bar:
126138
ts_process_time_start = time.time()
127139
for match in progress_bar(self.training_set, total=len(self.training_set)):
128140
self.process_trueskill(match=match)
@@ -131,22 +143,23 @@ def process(self):
131143

132144
# Process Test Set
133145
title = HTML(f'<style fg="Red">Processing Test Set</style>')
134-
with ProgressBar(title=title) as progress_bar:
146+
with ProgressBar(title=title, style=style) as progress_bar:
135147
for match in progress_bar(self.test_set, total=len(self.test_set)):
136148
if self.valid_test(match=match):
137149
self.verified_test_set.append(match)
150+
self.valid_matches += 1
138151

139152
# Predict OpenSkill Matches
140153
title = HTML(f'<style fg="Blue">Predicting OpenSkill Matches:</style>')
141-
with ProgressBar(title=title) as progress_bar:
154+
with ProgressBar(title=title, style=style) as progress_bar:
142155
for match in progress_bar(
143156
self.verified_test_set, total=len(self.verified_test_set)
144157
):
145158
self.predict_openskill(match=match)
146159

147160
# Predict TrueSkill Matches
148161
title = HTML(f'<style fg="Blue">Predicting TrueSkill Matches:</style>')
149-
with ProgressBar(title=title) as progress_bar:
162+
with ProgressBar(title=title, style=style) as progress_bar:
150163
for match in progress_bar(
151164
self.verified_test_set, total=len(self.verified_test_set)
152165
):
@@ -157,9 +170,10 @@ def print_result(self):
157170
print("-" * 40)
158171
print(
159172
HTML(
160-
f"Confident Matches: <style fg='Yellow'>{self.confident_matches}</style>"
173+
f"Available Matches: <style fg='Yellow'>{self.available_matches}</style>"
161174
)
162175
)
176+
print(HTML(f"Valid Matches: <style fg='Yellow'>{self.valid_matches}</style>"))
163177
print(
164178
HTML(
165179
f"Predictions Made with OpenSkill's <style fg='Green'><u>{self.model.__name__}</u></style> Model:"
@@ -278,13 +292,13 @@ def has_sufficient_history(self, match):
278292
white_player: dict = match["white_username"]
279293
black_player: dict = match["black_username"]
280294

281-
if self.match_count[white_player] < 2:
295+
if self.match_count[white_player] < self.minimum_matches:
282296
return False
283297

284-
if self.match_count[black_player] < 2:
298+
if self.match_count[black_player] < self.minimum_matches:
285299
return False
286300

287-
self.confident_matches += 1
301+
self.available_matches += 1
288302
return True
289303

290304
def process_openskill(self, match):

0 commit comments

Comments
 (0)