diff --git a/tcllib/requests/serverselector.py b/tcllib/requests/serverselector.py index 2877353..cd2236f 100644 --- a/tcllib/requests/serverselector.py +++ b/tcllib/requests/serverselector.py @@ -20,14 +20,18 @@ MASTER_SERVERS = [ class ServerSelector: """Returns a random server to use.""" - def __init__(self): + def __init__(self, server_list=None): """Init stuff""" + if server_list: + self.server_list = server_list + else: + self.server_list = MASTER_SERVERS self.last_server = None def get_master_server(self): """Return a random server.""" while True: - new_server = numpy.random.choice(MASTER_SERVERS) + new_server = numpy.random.choice(self.server_list) if new_server != self.last_server: break self.last_server = new_server @@ -44,35 +48,35 @@ class ServerSelector: class ServerVoteSelector(ServerSelector): """Tries to return faster servers more often.""" - def __init__(self): + def __init__(self, server_list=None): """Populate server list and weighting variables.""" - self.last_server = None - self.master_servers_weights = [3] * len(MASTER_SERVERS) + super().__init__(server_list) + self.servers_weights = [3] * len(self.server_list) self.check_time_sum = 3 self.check_time_count = 1 def get_master_server(self): """Return weighted choice from server list.""" weight_sum = 0 - for i in self.master_servers_weights: + for i in self.servers_weights: weight_sum += i numpy_weights = [] - for i in self.master_servers_weights: + for i in self.servers_weights: numpy_weights.append(i/weight_sum) - self.last_server = numpy.random.choice(MASTER_SERVERS, p=numpy_weights) + self.last_server = numpy.random.choice(self.server_list, p=numpy_weights) return self.last_server def master_server_downvote(self): """Decrease weight of last chosen server.""" - idx = MASTER_SERVERS.index(self.last_server) - if self.master_servers_weights[idx] > 1: - self.master_servers_weights[idx] -= 1 + idx = self.server_list.index(self.last_server) + if self.servers_weights[idx] > 1: + self.servers_weights[idx] -= 1 def master_server_upvote(self): """Increase weight of last chosen server.""" - idx = MASTER_SERVERS.index(self.last_server) - if self.master_servers_weights[idx] < 10: - self.master_servers_weights[idx] += 1 + idx = self.server_list.index(self.last_server) + if self.servers_weights[idx] < 10: + self.servers_weights[idx] += 1 def check_time_add(self, duration): """Record connection time."""