44from typing import Union
55
66import jsonlines
7+ import numpy as np
78import trueskill
89from prompt_toolkit import HTML
910from prompt_toolkit import print_formatted_text as print
1011from prompt_toolkit import prompt
1112from prompt_toolkit .completion import WordCompleter
1213from prompt_toolkit .shortcuts import ProgressBar
14+ from sklearn .model_selection import train_test_split
1315
1416import openskill
1517from openskill .models import (
2426os_players = {}
2527ts_players = {}
2628
29+ match_count = {}
30+
31+ matches = []
32+ training_set = {}
33+ test_set = {}
34+ valid_test_set_matches = []
35+
2736# Counters
2837os_correct_predictions = 0
2938os_incorrect_predictions = 0
3039ts_correct_predictions = 0
3140ts_incorrect_predictions = 0
41+ confident_matches = 0
3242
3343
3444print (HTML ("<u><b>Benchmark Starting</b></u>" ))
@@ -144,11 +154,15 @@ def predict_os_match(match: dict):
144154 for player in red_team :
145155 os_red_players [player ] = os_players [player ]
146156
147- blue_win_probability , red_win_probability = openskill .predict_win (
157+ blue_win_probability , red_win_probability = openskill .predict_rank (
148158 [list (os_blue_players .values ()), list (os_red_players .values ())]
149159 )
150- if (blue_win_probability > red_win_probability ) == won :
151- global os_correct_predictions
160+ blue_win_probability = blue_win_probability [0 ]
161+ red_win_probability = red_win_probability [0 ]
162+ global os_correct_predictions
163+ if (blue_win_probability < red_win_probability ) == won :
164+ os_correct_predictions += 1
165+ elif blue_win_probability == red_win_probability : # Draw
152166 os_correct_predictions += 1
153167 else :
154168 global os_incorrect_predictions
@@ -179,7 +193,7 @@ def predict_ts_match(match: dict):
179193 ts_blue_players [player ] = ts_players [player ]
180194
181195 for player in red_team :
182- ts_red_players [player ] = os_players [player ]
196+ ts_red_players [player ] = ts_players [player ]
183197
184198 blue_win_probability = win_probability (
185199 list (ts_blue_players .values ()), list (ts_red_players .values ())
@@ -193,6 +207,52 @@ def predict_ts_match(match: dict):
193207 ts_incorrect_predictions += 1
194208
195209
def process_match(match: dict):
    """Count one appearance in ``match_count`` for every player in ``match``.

    Reads the ``"blue"`` and ``"red"`` entries under ``match["teams"]`` and
    increments the module-level ``match_count`` entry of each player found
    there, starting from zero for players that have no entry yet.
    """
    rosters: dict = match.get("teams")
    blue_side: dict = rosters.get("blue")
    red_side: dict = rosters.get("red")

    # Blue is tallied before red, exactly as two back-to-back loops would.
    for side in (blue_side, red_side):
        for member in side:
            match_count[member] = match_count.get(member, 0) + 1
220+
221+
def valid_test_set(match: dict):
    """Return ``True`` when every player in ``match`` has an ``os_players`` entry.

    Players are taken from the ``"blue"`` and then the ``"red"`` entry under
    ``match["teams"]``. The check stops at the first player missing from the
    module-level ``os_players`` mapping and returns ``False``. Only
    ``os_players`` is consulted here.
    """
    rosters: dict = match.get("teams")
    blue_side: dict = rosters.get("blue")
    red_side: dict = rosters.get("red")

    # all() short-circuits, so red is never touched once a blue player fails.
    return all(
        member in os_players
        for side in (blue_side, red_side)
        for member in side
    )
236+
237+
def confident_in_match(match: dict) -> bool:
    """Tell whether every player in ``match`` has at least two counted matches.

    Looks up each blue player, then each red player, in the module-level
    ``match_count`` mapping. Returns ``False`` as soon as one of them has a
    count below 2. Otherwise increments the module-level
    ``confident_matches`` counter and returns ``True``.

    Raises:
        KeyError: if a player has no entry in ``match_count``.
    """
    global confident_matches

    rosters: dict = match.get("teams")
    blue_side: dict = rosters.get("blue")
    red_side: dict = rosters.get("red")

    for side in (blue_side, red_side):
        # any() stops at the first under-sampled player on this side.
        if any(match_count[member] < 2 for member in side):
            return False

    # Only matches that pass the check for both teams are tallied.
    confident_matches += 1
    return True
254+
255+
196256models = [
197257 BradleyTerryFull ,
198258 BradleyTerryPart ,
@@ -203,6 +263,7 @@ def predict_ts_match(match: dict):
203263model_names = [m .__name__ for m in models ]
204264model_completer = WordCompleter (model_names )
205265input_model = prompt ("Enter Model: " , completer = model_completer )
266+
206267if input_model in model_names :
207268 index = model_names .index (input_model )
208269else :
@@ -211,41 +272,71 @@ def predict_ts_match(match: dict):
211272with jsonlines .open ("v2_jsonl_teams.jsonl" ) as reader :
212273 lines = list (reader .iter ())
213274
214- # Process OpenSkill Ratings
215- title = HTML (f'Updating Ratings with <style fg="Green">{ input_model } </style> Model' )
275+ title = HTML (f'<style fg="Red">Processing Matches</style>' )
216276 with ProgressBar (title = title ) as progress_bar :
217- os_process_time_start = time .time ()
218277 for line in progress_bar (lines , total = len (lines )):
219278 if data_verified (match = line ):
220- process_os_match (match = line , model = models [index ])
279+ process_match (match = line )
280+
281+ # Measure Confidence
282+ title = HTML (f'<style fg="Red">Splitting Data</style>' )
283+ with ProgressBar (title = title ) as progress_bar :
284+ for line in progress_bar (lines , total = len (lines )):
285+ if data_verified (match = line ):
286+ if confident_in_match (match = line ):
287+ matches .append (line )
288+
289+ # Split Data
290+ training_set , test_set = train_test_split (
291+ matches , test_size = 0.33 , random_state = True
292+ )
293+
294+ # Process OpenSkill Ratings
295+ title = HTML (
296+ f'Updating Ratings with <style fg="Green">{ input_model } </style> Model:'
297+ )
298+ with ProgressBar (title = title ) as progress_bar :
299+ os_process_time_start = time .time ()
300+ for line in progress_bar (training_set , total = len (training_set )):
301+ process_os_match (match = line , model = models [index ])
221302 os_process_time_stop = time .time ()
222303 os_time = os_process_time_stop - os_process_time_start
223304
224305 # Process TrueSkill Ratings
225- title = HTML (f'Updating Ratings with <style fg="Green">TrueSkill</style> Model' )
306+ title = HTML (f'Updating Ratings with <style fg="Green">TrueSkill</style> Model: ' )
226307 with ProgressBar (title = title ) as progress_bar :
227308 ts_process_time_start = time .time ()
228- for line in progress_bar (lines , total = len (lines )):
229- if data_verified (match = line ):
230- process_ts_match (match = line )
309+ for line in progress_bar (training_set , total = len (training_set )):
310+ process_ts_match (match = line )
231311 ts_process_time_stop = time .time ()
232312 ts_time = ts_process_time_stop - ts_process_time_start
233313
314+ # Process Test Set
315+ title = HTML (f'<style fg="Red">Processing Test Set</style>' )
316+ with ProgressBar (title = title ) as progress_bar :
317+ for line in progress_bar (test_set , total = len (test_set )):
318+ if valid_test_set (match = line ):
319+ valid_test_set_matches .append (line )
320+
234321 # Predict OpenSkill Matches
235322 title = HTML (f'<style fg="Blue">Predicting OpenSkill Matches:</style>' )
236323 with ProgressBar (title = title ) as progress_bar :
237- for line in progress_bar (lines , total = len (lines )):
238- if data_verified (match = line ):
239- predict_os_match (match = line )
324+ for line in progress_bar (
325+ valid_test_set_matches , total = len (valid_test_set_matches )
326+ ):
327+ predict_os_match (match = line )
240328
241329 # Predict TrueSkill Matches
242330 title = HTML (f'<style fg="Blue">Predicting TrueSkill Matches:</style>' )
243331 with ProgressBar (title = title ) as progress_bar :
244- for line in progress_bar (lines , total = len (lines )):
245- if data_verified (match = line ):
246- predict_ts_match (match = line )
332+ for line in progress_bar (
333+ valid_test_set_matches , total = len (valid_test_set_matches )
334+ ):
335+ predict_ts_match (match = line )
247336
337+ mean = float (np .array (list (match_count .values ())).mean ())
248338
339+ print (HTML (f"Confident Matches: <style fg='Yellow'>{ confident_matches } </style>" ))
249340print (
250341 HTML (
251342 f"Predictions Made with OpenSkill's <style fg='Green'><u>{ input_model } </u></style> Model:"
@@ -281,3 +372,4 @@ def predict_ts_match(match: dict):
281372 )
282373)
283374print (HTML (f"Process Duration: <style fg='Yellow'>{ ts_time } </style>" ))
375+ print (HTML (f"Mean Matches: <style fg='Yellow'>{ mean } </style>" ))
0 commit comments