99from prompt_toolkit import HTML
1010from prompt_toolkit import print_formatted_text as print
1111from prompt_toolkit .shortcuts import ProgressBar
12+ from prompt_toolkit .styles import Style
1213from sklearn .model_selection import train_test_split
1314
14- import openskill
1515from openskill .models import (
1616 BradleyTerryFull ,
1717 BradleyTerryPart ,
@@ -46,6 +46,7 @@ def __init__(
4646 self ,
4747 path ,
4848 seed : int ,
49+ minimum_matches : int ,
4950 model : Union [
5051 BradleyTerryFull ,
5152 BradleyTerryPart ,
@@ -60,11 +61,13 @@ def __init__(
6061 for match_index , row in df .iterrows ():
6162 self .data .append (row )
6263 self .seed = seed
64+ self .minimum_matches = minimum_matches
6365 self .model = model
6466
6567 # Counters
6668 self .match_count = {}
67- self .confident_matches = 0
69+ self .available_matches = 0
70+ self .valid_matches = 0
6871 self .openskill_correct_predictions = 0
6972 self .openskill_incorrect_predictions = 0
7073 self .trueskill_correct_predictions = 0
@@ -87,15 +90,24 @@ def __init__(
8790 self .trueskill_time = None
8891
8992 def process (self ):
93+ style = Style .from_dict (
94+ {
95+ "label" : "bg:#ffff00 #000000" ,
96+ "percentage" : "bg:#ffff00 #000000" ,
97+ "current" : "#448844" ,
98+ "bar" : "" ,
99+ }
100+ )
101+
90102 title = HTML (f'<style fg="Red">Counting Matches</style>' )
91- with ProgressBar (title = title ) as progress_bar :
103+ with ProgressBar (title = title , style = style ) as progress_bar :
92104 for match in progress_bar (self .data , total = len (self .data )):
93105 if self .consistent (match = match ):
94106 self .count (match = match )
95107
96108 # Check if data has sufficient history.
97109 title = HTML (f'<style fg="Red">Verifying History</style>' )
98- with ProgressBar (title = title ) as progress_bar :
110+ with ProgressBar (title = title , style = style ) as progress_bar :
99111 for match in progress_bar (self .data , total = len (self .data )):
100112 if self .consistent (match = match ):
101113 if self .has_sufficient_history (match = match ):
@@ -111,7 +123,7 @@ def process(self):
111123 title = HTML (
112124 f'Updating OpenSkill Ratings with <style fg="Green">{ self .model .__name__ } </style> Model:'
113125 )
114- with ProgressBar (title = title ) as progress_bar :
126+ with ProgressBar (title = title , style = style ) as progress_bar :
115127 os_process_time_start = time .time ()
116128 for match in progress_bar (self .training_set , total = len (self .training_set )):
117129 self .process_openskill (match = match )
@@ -122,7 +134,7 @@ def process(self):
122134 title = HTML (
123135 f'Updating Ratings with <style fg="Green">TrueSkill</style> Model:'
124136 )
125- with ProgressBar (title = title ) as progress_bar :
137+ with ProgressBar (title = title , style = style ) as progress_bar :
126138 ts_process_time_start = time .time ()
127139 for match in progress_bar (self .training_set , total = len (self .training_set )):
128140 self .process_trueskill (match = match )
@@ -131,22 +143,23 @@ def process(self):
131143
132144 # Process Test Set
133145 title = HTML (f'<style fg="Red">Processing Test Set</style>' )
134- with ProgressBar (title = title ) as progress_bar :
146+ with ProgressBar (title = title , style = style ) as progress_bar :
135147 for match in progress_bar (self .test_set , total = len (self .test_set )):
136148 if self .valid_test (match = match ):
137149 self .verified_test_set .append (match )
150+ self .valid_matches += 1
138151
139152 # Predict OpenSkill Matches
140153 title = HTML (f'<style fg="Blue">Predicting OpenSkill Matches:</style>' )
141- with ProgressBar (title = title ) as progress_bar :
154+ with ProgressBar (title = title , style = style ) as progress_bar :
142155 for match in progress_bar (
143156 self .verified_test_set , total = len (self .verified_test_set )
144157 ):
145158 self .predict_openskill (match = match )
146159
147160 # Predict TrueSkill Matches
148161 title = HTML (f'<style fg="Blue">Predicting TrueSkill Matches:</style>' )
149- with ProgressBar (title = title ) as progress_bar :
162+ with ProgressBar (title = title , style = style ) as progress_bar :
150163 for match in progress_bar (
151164 self .verified_test_set , total = len (self .verified_test_set )
152165 ):
@@ -157,9 +170,10 @@ def print_result(self):
157170 print ("-" * 40 )
158171 print (
159172 HTML (
160- f"Confident Matches: <style fg='Yellow'>{ self .confident_matches } </style>"
173+ f"Available Matches: <style fg='Yellow'>{ self .available_matches } </style>"
161174 )
162175 )
176+ print (HTML (f"Valid Matches: <style fg='Yellow'>{ self .valid_matches } </style>" ))
163177 print (
164178 HTML (
165179 f"Predictions Made with OpenSkill's <style fg='Green'><u>{ self .model .__name__ } </u></style> Model:"
@@ -278,13 +292,13 @@ def has_sufficient_history(self, match):
278292 white_player : dict = match ["white_username" ]
279293 black_player : dict = match ["black_username" ]
280294
281- if self .match_count [white_player ] < 2 :
295+ if self .match_count [white_player ] < self . minimum_matches :
282296 return False
283297
284- if self .match_count [black_player ] < 2 :
298+ if self .match_count [black_player ] < self . minimum_matches :
285299 return False
286300
287- self .confident_matches += 1
301+ self .available_matches += 1
288302 return True
289303
290304 def process_openskill (self , match ):
0 commit comments