132 lines
5.3 KiB
Python
132 lines
5.3 KiB
Python
import atexit
import base64
import functools
import sqlite3
import sys
import threading
import time
|
|
|
|
class Persistent:
    """Interface for the state stores used by the rate-limiting decorators.

    ``limitCallsPerSecond`` calls ``getCallsPerSecond`` / ``setCallsPerSecond``
    and ``limitCallsRate`` calls ``getCallsRate`` / ``setCallsRate``. Entries
    are keyed by the decorated function's ``__name__`` plus a configuration
    hash string. Concrete stores override every method below.
    """

    def getCallsPerSecond(self, function, hash):
        """Return ``([last_reset], [calls])`` stored for ``function``/``hash``."""
        raise NotImplementedError

    def setCallsPerSecond(self, function, hash, lastReset, calls, expire):
        """Store the fixed-window state for ``function``/``hash``."""
        raise NotImplementedError

    def getCallsRate(self, function, hash):
        """Return ``[last_call]`` stored for ``function``/``hash``."""
        raise NotImplementedError

    def setCallsRate(self, function, hash, lastCall, expire):
        """Store the last-call timestamp for ``function``/``hash``."""
        raise NotImplementedError


class SQLiteImplementation(Persistent):
    """SQLite-backed store for rate-limiter state.

    Table ``per_second`` holds (last_reset, calls) for ``limitCallsPerSecond``;
    table ``rate`` holds last_call for ``limitCallsRate``. Each row also has an
    ``expire`` timestamp after which it is no longer needed.

    NOTE(review): ``sqlite3.connect`` defaults to ``check_same_thread=True``,
    so this store may only be used from the thread that created it — confirm
    that matches how the decorated functions are called.
    """

    def __init__(self, database='limit-calls.db', session='-', reset=True):
        """Open ``database`` and make sure both tables exist.

        :param database: path (or ``':memory:'``) passed to ``sqlite3.connect``.
        :param session: accepted but not used by this implementation.
        :param reset: when true (the historical default) both tables are
            dropped first, discarding state left by earlier runs. Pass
            ``False`` to keep state across runs.
        """
        self.connection = sqlite3.connect(database)
        self.cursor = self.connection.cursor()
        if reset:
            self.cursor.execute('DROP TABLE IF EXISTS per_second')
            self.cursor.execute('DROP TABLE IF EXISTS rate')
        self.cursor.execute('CREATE TABLE IF NOT EXISTS per_second (id INTEGER PRIMARY KEY AUTOINCREMENT, target string, hash string, last_reset real, calls int, expire real)')
        self.cursor.execute('CREATE TABLE IF NOT EXISTS rate (id INTEGER PRIMARY KEY AUTOINCREMENT, target string, hash string, last_call real, expire real)')
        self.connection.commit()

    def purgeExpired(self, now=None):
        """Delete rows whose ``expire`` timestamp is before ``now``.

        :param now: reference time; defaults to ``time.time()``.
        """
        if now is None:
            now = time.time()
        # Rows with expire in the future still describe a live window and
        # must be kept; only stale ones are removed.
        self.cursor.execute('DELETE FROM per_second WHERE expire < ?', (now,))
        self.cursor.execute('DELETE FROM rate WHERE expire < ?', (now,))
        self.connection.commit()

    def close(self):
        """Purge expired rows and close the connection. Safe to call twice."""
        # getattr: __init__ may have failed before the connection existed.
        connection = getattr(self, 'connection', None)
        if connection is None:
            return
        try:
            self.purgeExpired()
        finally:
            self.connection = None
            connection.close()

    def __del__(self):
        self.close()

    def getCallsPerSecond(self, function, hash):
        """Return ``([last_reset], [calls])`` for ``function``/``hash``.

        ``([0.0], [0])`` is returned when nothing is stored yet. Values are
        wrapped in one-element lists so callers can update them in place.
        """
        self.cursor.execute(
            'SELECT last_reset, calls FROM per_second WHERE target = ? AND hash = ?',
            (function.__name__, hash))
        row = self.cursor.fetchone()
        if row is None:
            return ([0.0], [0])
        return ([row[0]], [row[1]])

    def setCallsPerSecond(self, function, hash, lastReset, calls, expire):
        """Insert or overwrite the ``per_second`` row for ``function``/``hash``."""
        # The sub-select reuses the existing row id (if any) so REPLACE
        # overwrites that row instead of adding a duplicate.
        self.cursor.execute(
            'REPLACE INTO per_second (id, target, hash, last_reset, calls, expire) '
            'VALUES ((SELECT id FROM per_second WHERE target = :target AND hash = :hash), '
            ':target, :hash, :last_reset, :calls, :expire)',
            {'target': function.__name__, 'hash': hash, 'last_reset': lastReset,
             'calls': calls, 'expire': expire})
        self.connection.commit()

    def getCallsRate(self, function, hash):
        """Return ``[last_call]`` for ``function``/``hash`` (``[0.0]`` if unset)."""
        self.cursor.execute(
            'SELECT last_call FROM rate WHERE target = ? AND hash = ?',
            (function.__name__, hash))
        row = self.cursor.fetchone()
        if row is None:
            return [0.0]
        return [row[0]]

    def setCallsRate(self, function, hash, lastCall, expire):
        """Insert or overwrite the ``rate`` row for ``function``/``hash``."""
        self.cursor.execute(
            'REPLACE INTO rate (id, target, hash, last_call, expire) '
            'VALUES ((SELECT id FROM rate WHERE target = :target AND hash = :hash), '
            ':target, :hash, :last_call, :expire)',
            {'target': function.__name__, 'hash': hash, 'last_call': lastCall,
             'expire': expire})
        self.connection.commit()
|
|
|
|
def limitCallsPerSecond(maxCalls, perSeconds, persistent, sleep=True):
    """Decorator factory: allow at most ``maxCalls`` calls per ``perSeconds`` window.

    The window is fixed, not sliding: it restarts at the first call made more
    than ``perSeconds`` after the previous restart. Window start and remaining
    calls are written to ``persistent`` on every call and read back once, when
    the decorator is applied.

    :param maxCalls: calls allowed per window.
    :param perSeconds: window length in seconds.
    :param persistent: store providing ``getCallsPerSecond`` and
        ``setCallsPerSecond`` (e.g. ``SQLiteImplementation``).
    :param sleep: if true, a call made after the budget is exhausted blocks
        until the window ends and then runs; if false it is dropped and the
        wrapper returns ``None``.
    """
    def decorate(function):
        lock = threading.RLock()
        # b64encode needs bytes on Python 3; decode so the key is a str.
        # '%d' truncates fractional values; kept so existing keys still match.
        hash = base64.b64encode(('%d-%d-%s' % (maxCalls, perSeconds, sleep)).encode()).decode()
        # One-element lists so the nested closures can update them in place.
        (lastReset, calls) = persistent.getCallsPerSecond(function, hash)

        def store(expire):
            persistent.setCallsPerSecond(function, hash, lastReset[0], calls[0], expire)

        def reset(now=None):
            # Read the clock at call time; a ``time.time()`` default would be
            # evaluated once, at decoration time.
            if now is None:
                now = time.time()
            lastReset[0] = now
            calls[0] = maxCalls
            store(now + perSeconds)

        @functools.wraps(function)
        def wrapper(*args, **kargs):
            with lock:
                now = time.time()
                sinceLastReset = now - lastReset[0]
                if sinceLastReset > perSeconds:
                    reset(now)
                else:
                    calls[0] = calls[0] - 1
                    store(now + perSeconds)
                outOfCalls = calls[0] < 1
                if outOfCalls and sleep:
                    # Wait out the rest of the window, then start a new one
                    # with this call as its first.
                    time.sleep(max(0.0, perSeconds - sinceLastReset))
                    reset()
                    outOfCalls = calls[0] < 1
            if not outOfCalls:
                return function(*args, **kargs)
            return None

        return wrapper

    return decorate
|
|
|
|
def limitCallsRate(maxPerSecond, perSecond, persistent, sleep=True):
    """Decorator factory: keep calls at least ``perSecond / maxPerSecond`` seconds apart.

    The time of the last call is written to ``persistent`` after every call
    (including calls that raise) and read back once, when the decorator is
    applied. The wrapped function runs while the per-function lock is held,
    so concurrent callers of the same function are serialized.

    :param maxPerSecond: calls allowed per ``perSecond`` seconds.
    :param perSecond: period in seconds.
    :param persistent: store providing ``getCallsRate`` and ``setCallsRate``
        (e.g. ``SQLiteImplementation``).
    :param sleep: if true, a call that comes too early waits until the minimum
        interval has passed; if false it is dropped and ``None`` is returned.
    """
    def decorate(function):
        lock = threading.RLock()
        minInterval = perSecond / float(maxPerSecond)
        hash = base64.b64encode(('%d-%d-%s' % (maxPerSecond, perSecond, sleep)).encode()).decode()
        # One-element list so the nested closures can update it in place.
        lastCall = persistent.getCallsRate(function, hash)

        def store(expire):
            persistent.setCallsRate(function, hash, lastCall[0], expire)

        @functools.wraps(function)
        def wrapper(*args, **kargs):
            # ``with`` releases the lock even if the function or store raises.
            with lock:
                leftToWait = minInterval - (time.time() - lastCall[0])
                if leftToWait > 0:
                    if not sleep:
                        return None
                    time.sleep(leftToWait)
                try:
                    return function(*args, **kargs)
                finally:
                    lastCall[0] = time.time()
                    store(lastCall[0] + minInterval)

        return wrapper

    return decorate
|
|
|
|
if __name__ == "__main__":
    persistent = SQLiteImplementation()

    @limitCallsPerSecond(3, 4, persistent)
    @limitCallsRate(2, 1, persistent)
    def PrintNumber(num):
        """Print ``num`` with the current timestamp, pause briefly, return True."""
        print("%s: %d" % (time.time(), num))
        time.sleep(0.01)
        return True

    # A falsy result means a limiter dropped the call; back off before the next.
    for i in range(1, 10000):
        if not PrintNumber(i):
            time.sleep(0.1)