Skip to content

Commit 0e0e89a

Browse files
authored
Merge pull request #27 from OpenDebates/predict-win
Predict win probability of each team.
2 parents d1e3020 + 476f953 commit 0e0e89a

File tree

10 files changed

+62367
-56
lines changed

10 files changed

+62367
-56
lines changed

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,21 @@ The default model is `PlackettLuce`. You can import alternate models from `opens
9191
[[[17.09430584957905, 7.5012190693964005]], [[32.90569415042095, 7.5012190693964005]], [[22.36476861652635, 7.5012190693964005]], [[27.63523138347365, 7.5012190693964005]]]
9292
```
9393

94+
## Predicting Winners
95+
96+
You can compare two or more teams to get the probabilities of each team winning.
97+
98+
```python
99+
>>> from openskill import predict_win
100+
>>> a1 = Rating()
101+
>>> a2 = Rating(mu=33.564, sigma=1.123)
102+
>>> predictions = predict_win(teams=[[a1], [a2]])
103+
>>> predictions
104+
[0.45110901512761536, 0.5488909848723846]
105+
>>> sum(predictions)
106+
1.0
107+
```
108+
94109
### Available Models
95110
- `BradleyTerryFull`: Full Pairing for Bradley-Terry
96111
- `BradleyTerryPart`: Partial Pairing for Bradley-Terry

benchmark/benchmark.py

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
import itertools
2+
import math
3+
import time
4+
from typing import Union
5+
6+
import jsonlines
7+
import trueskill
8+
from prompt_toolkit import print_formatted_text as print, HTML, prompt
9+
from prompt_toolkit.completion import WordCompleter
10+
from prompt_toolkit.shortcuts import ProgressBar
11+
12+
import openskill
13+
from openskill.models import (
14+
ThurstoneMostellerPart,
15+
ThurstoneMostellerFull,
16+
BradleyTerryFull,
17+
BradleyTerryPart,
18+
PlackettLuce,
19+
)
20+
21+
# Rating stores shared by the processing and prediction functions; keys are
# the entries obtained by iterating each match's team collections.
os_players = dict()
ts_players = dict()

# Running tallies of prediction outcomes for each rating system.
os_correct_predictions = os_incorrect_predictions = 0
ts_correct_predictions = ts_incorrect_predictions = 0


print(HTML("<u><b>Benchmark Starting</b></u>"))
33+
34+
35+
def data_verified(match: dict) -> bool:
    """Return True when *match* has the shape the benchmark can process.

    A usable match has a ``"result"`` of ``"WIN"`` or ``"LOSS"`` and a
    ``"teams"`` mapping with exactly the keys ``"blue"`` and ``"red"``, each
    holding at least one player.

    :param match: One decoded record from the match data file.
    :return: ``True`` if the record is usable, ``False`` otherwise.
    """
    if match.get("result") not in ("WIN", "LOSS"):
        return False

    teams = match.get("teams")
    # A record without a "teams" mapping is rejected instead of raising.
    if not isinstance(teams, dict):
        return False

    # Key order is irrelevant because both teams are looked up by name.
    if set(teams) != {"blue", "red"}:
        return False

    # Either team being empty makes the match unusable, so each side is
    # checked on its own.
    if not teams["blue"] or not teams["red"]:
        return False

    return True
51+
52+
53+
def process_os_match(
    match: dict,
    model: Union[
        BradleyTerryFull,
        BradleyTerryPart,
        PlackettLuce,
        ThurstoneMostellerFull,
        ThurstoneMostellerPart,
    ] = PlackettLuce,
):
    """Update the OpenSkill ratings in ``os_players`` from one match.

    Each player starts from the rating already stored in ``os_players`` (or a
    default ``openskill.Rating()`` on first appearance), so ratings accumulate
    across matches instead of being reset on every call.

    :param match: Match record with a ``"result"`` and a ``"teams"`` mapping
        holding ``"blue"`` and ``"red"`` entries.
    :param model: OpenSkill model forwarded to ``openskill.rate``.
    """
    blue_won = match.get("result") == "WIN"

    teams: dict = match.get("teams")
    blue_team: dict = teams.get("blue")
    red_team: dict = teams.get("red")

    # Carry forward any rating learnt from earlier matches.
    os_blue_players = {
        player: os_players[player] if player in os_players else openskill.Rating()
        for player in blue_team
    }
    os_red_players = {
        player: os_players[player] if player in os_players else openskill.Rating()
        for player in red_team
    }

    blue_ratings = list(os_blue_players.values())
    red_ratings = list(os_red_players.values())

    # The winning team is passed first; a "WIN" result means blue won.
    if blue_won:
        blue_team_result, red_team_result = openskill.rate(
            [blue_ratings, red_ratings], model=model
        )
    else:
        red_team_result, blue_team_result = openskill.rate(
            [red_ratings, blue_ratings], model=model
        )

    # Convert the values returned by rate() back into rating objects.
    blue_team_ratings = [openskill.create_rating(_) for _ in blue_team_result]
    red_team_ratings = [openskill.create_rating(_) for _ in red_team_result]

    os_players.update(zip(os_blue_players, blue_team_ratings))
    os_players.update(zip(os_red_players, red_team_ratings))
96+
97+
98+
def process_ts_match(match: dict):
    """Update the TrueSkill ratings in ``ts_players`` from one match.

    Each player starts from the rating already stored in ``ts_players`` (or a
    default ``trueskill.Rating()`` on first appearance), so ratings accumulate
    across matches instead of being reset on every call.

    :param match: Match record with a ``"result"`` and a ``"teams"`` mapping
        holding ``"blue"`` and ``"red"`` entries.
    """
    blue_won = match.get("result") == "WIN"

    teams: dict = match.get("teams")
    blue_team: dict = teams.get("blue")
    red_team: dict = teams.get("red")

    # Carry forward any rating learnt from earlier matches.
    ts_blue_players = {
        player: ts_players[player] if player in ts_players else trueskill.Rating()
        for player in blue_team
    }
    ts_red_players = {
        player: ts_players[player] if player in ts_players else trueskill.Rating()
        for player in red_team
    }

    blue_ratings = list(ts_blue_players.values())
    red_ratings = list(ts_red_players.values())

    # The winning team is passed first; a "WIN" result means blue won.
    if blue_won:
        blue_team_ratings, red_team_ratings = trueskill.rate(
            [blue_ratings, red_ratings]
        )
    else:
        red_team_ratings, blue_team_ratings = trueskill.rate(
            [red_ratings, blue_ratings]
        )

    ts_players.update(zip(ts_blue_players, blue_team_ratings))
    ts_players.update(zip(ts_red_players, red_team_ratings))
129+
130+
131+
def predict_os_match(match: dict):
    """Score one OpenSkill prediction against the recorded match result.

    Looks up every player's rating in ``os_players``, asks
    ``openskill.predict_win`` for both teams' win probabilities, and bumps
    ``os_correct_predictions`` or ``os_incorrect_predictions`` depending on
    whether "blue is favoured" agrees with "blue won".

    :param match: Match record with a ``"result"`` and a ``"teams"`` mapping
        holding ``"blue"`` and ``"red"`` entries.
    """
    global os_correct_predictions, os_incorrect_predictions

    blue_actually_won = match.get("result") == "WIN"

    sides: dict = match.get("teams")
    blue_ratings = {name: os_players[name] for name in sides.get("blue")}
    red_ratings = {name: os_players[name] for name in sides.get("red")}

    blue_win_probability, red_win_probability = openskill.predict_win(
        [list(blue_ratings.values()), list(red_ratings.values())]
    )

    blue_favoured = blue_win_probability > red_win_probability
    if blue_favoured == blue_actually_won:
        os_correct_predictions += 1
    else:
        os_incorrect_predictions += 1
157+
158+
159+
def win_probability(team1, team2):
    """Return the TrueSkill probability that *team1* beats *team2*.

    :param team1: Sequence of rating objects exposing ``mu`` and ``sigma``.
    :param team2: Sequence of rating objects exposing ``mu`` and ``sigma``.
    :return: Value of the global TrueSkill environment's ``cdf`` at the
        difference in summed ``mu`` divided by the combined spread.
    """
    env = trueskill.global_env()
    player_count = len(team1) + len(team2)
    variance_total = sum(
        rating.sigma ** 2 for rating in itertools.chain(team1, team2)
    )
    mu_gap = sum(rating.mu for rating in team1) - sum(rating.mu for rating in team2)
    spread = math.sqrt(
        player_count * (trueskill.BETA * trueskill.BETA) + variance_total
    )
    return env.cdf(mu_gap / spread)
166+
167+
168+
def predict_ts_match(match: dict):
    """Score one TrueSkill prediction against the recorded match result.

    Looks up every player's rating in ``ts_players`` (both teams), computes
    blue's win probability with ``win_probability`` and bumps
    ``ts_correct_predictions`` or ``ts_incorrect_predictions`` depending on
    whether "blue is favoured" agrees with "blue won".

    :param match: Match record with a ``"result"`` and a ``"teams"`` mapping
        holding ``"blue"`` and ``"red"`` entries.
    """
    global ts_correct_predictions, ts_incorrect_predictions

    won = match.get("result") == "WIN"

    teams: dict = match.get("teams")
    blue_team: dict = teams.get("blue")
    red_team: dict = teams.get("red")

    # Both teams must use TrueSkill ratings; mixing in OpenSkill ratings
    # would make the comparison between the two systems meaningless.
    ts_blue_players = {player: ts_players[player] for player in blue_team}
    ts_red_players = {player: ts_players[player] for player in red_team}

    blue_win_probability = win_probability(
        list(ts_blue_players.values()), list(ts_red_players.values())
    )
    # Two-team match: red's probability is the complement of blue's.
    red_win_probability = 1 - blue_win_probability

    if (blue_win_probability > red_win_probability) == won:
        ts_correct_predictions += 1
    else:
        ts_incorrect_predictions += 1
195+
196+
197+
# Rating models the user can choose between at the prompt.
models = [
    BradleyTerryFull,
    BradleyTerryPart,
    PlackettLuce,
    ThurstoneMostellerFull,
    ThurstoneMostellerPart,
]
model_names = [m.__name__ for m in models]
model_completer = WordCompleter(model_names)
# Surrounding whitespace is dropped so a stray space does not reject a valid name.
input_model = prompt("Enter Model: ", completer=model_completer).strip()
if input_model not in model_names:
    print(HTML("<style fg='Red'>Model Not Found</style>"))
    # Exit with a failure status. quit() is an interactive-shell helper that
    # is not guaranteed to exist and would report success (status 0).
    raise SystemExit(1)
index = model_names.index(input_model)

# Load every match record up front so each pass below reuses the same list.
with jsonlines.open("v2_jsonl_teams.jsonl") as reader:
    lines = list(reader.iter())
214+
215+
# Process OpenSkill Ratings
title = HTML(f'Updating Ratings with <style fg="Green">{input_model}</style> Model')
with ProgressBar(title=title) as progress_bar:
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    os_process_time_start = time.perf_counter()
    for line in progress_bar(lines, total=len(lines)):
        if data_verified(match=line):
            process_os_match(match=line, model=models[index])
    os_process_time_stop = time.perf_counter()
os_time = os_process_time_stop - os_process_time_start

# Process TrueSkill Ratings
title = HTML('Updating Ratings with <style fg="Green">TrueSkill</style> Model')
with ProgressBar(title=title) as progress_bar:
    ts_process_time_start = time.perf_counter()
    for line in progress_bar(lines, total=len(lines)):
        if data_verified(match=line):
            process_ts_match(match=line)
    ts_process_time_stop = time.perf_counter()
ts_time = ts_process_time_stop - ts_process_time_start
234+
235+
# Run the prediction pass once per rating system: OpenSkill first, then
# TrueSkill. Each pass scores every verified match with its predictor.
for heading, predictor in (
    ('<style fg="Blue">Predicting OpenSkill Matches:</style>', predict_os_match),
    ('<style fg="Blue">Predicting TrueSkill Matches:</style>', predict_ts_match),
):
    title = HTML(heading)
    with ProgressBar(title=title) as progress_bar:
        for line in progress_bar(lines, total=len(lines)):
            if not data_verified(match=line):
                continue
            predictor(match=line)
248+
249+
250+
def _accuracy_percent(correct: int, incorrect: int) -> float:
    """Return the share of correct predictions as a percentage (2 decimals).

    Returns ``0.0`` when no predictions were made, so the report cannot fail
    with a division by zero on an empty or fully rejected data set.
    """
    total = incorrect + correct
    if total == 0:
        return 0.0
    return round((correct / total) * 100, 2)


def _print_report(heading: str, correct: int, incorrect: int, duration) -> None:
    """Print one rating system's summary: heading, counts, accuracy, duration."""
    print(HTML(heading))
    print(
        HTML(
            f"Correct: <style fg='Yellow'>{correct}</style> | "
            f"Incorrect: <style fg='Yellow'>{incorrect}</style>"
        )
    )
    print(
        HTML(
            f"Accuracy: <style fg='Yellow'>"
            f"{_accuracy_percent(correct, incorrect)}%"
            f"</style>"
        )
    )
    print(HTML(f"Process Duration: <style fg='Yellow'>{duration}</style>"))


_print_report(
    f"Predictions Made with OpenSkill's <style fg='Green'><u>{input_model}</u></style> Model:",
    os_correct_predictions,
    os_incorrect_predictions,
    os_time,
)
print("-" * 40)
_print_report(
    "Predictions Made with <style fg='Green'><u>TrueSkill</u></style> Model:",
    ts_correct_predictions,
    ts_incorrect_predictions,
    ts_time,
)

0 commit comments

Comments
 (0)