Files
polar-flow-analysis/limit-calls.py
2015-12-25 11:09:42 +01:00

132 lines
5.3 KiB
Python

import atexit
import base64
import functools
import sqlite3
import sys
import threading
import time
class Persistent:
    """Storage interface used by ``limitCallsPerSecond`` and ``limitCallsRate``.

    Implementations provide ``getCallsPerSecond`` / ``setCallsPerSecond`` and
    ``getCallsRate`` / ``setCallsRate``; see ``SQLiteImplementation``.
    """


class SQLiteImplementation(Persistent):
    """``Persistent`` backend that keeps limiter state in an SQLite database.

    Table ``per_second`` holds the window start (``last_reset``) and call
    counter used by ``limitCallsPerSecond``; table ``rate`` holds the time of
    the last call (``last_call``) used by ``limitCallsRate``. Rows are looked
    up by the decorated function's ``__name__`` (column ``target``) plus a
    configuration hash. Each row also stores an ``expire`` timestamp, which
    the decorators set to the end of the current window / interval.
    """

    def __init__(self, database='limit-calls.db', session='-', reset=True):
        """Open (or create) ``database`` and make sure both tables exist.

        :param database: path handed to ``sqlite3.connect``.
        :param session: currently unused; kept for interface compatibility.
        :param reset: when True (the historical behaviour) existing tables
            are dropped, so no state survives between runs. Pass False to
            keep the stored limits across restarts.
        """
        self.connection = sqlite3.connect(database)
        self.cursor = self.connection.cursor()
        if reset:
            self.cursor.execute('DROP TABLE IF EXISTS per_second')
            self.cursor.execute('DROP TABLE IF EXISTS rate')
        # Declared as TEXT: a type name like "string" gets NUMERIC affinity in
        # SQLite, which may coerce number-looking keys.
        self.cursor.execute(
            'CREATE TABLE IF NOT EXISTS per_second (id INTEGER PRIMARY KEY AUTOINCREMENT, '
            'target TEXT, hash TEXT, last_reset REAL, calls INTEGER, expire REAL)')
        self.cursor.execute(
            'CREATE TABLE IF NOT EXISTS rate (id INTEGER PRIMARY KEY AUTOINCREMENT, '
            'target TEXT, hash TEXT, last_call REAL, expire REAL)')
        self.connection.commit()

    def close(self):
        """Purge expired rows, commit and close the connection.

        Safe to call more than once, and safe when ``__init__`` failed before
        the connection was created.
        """
        connection = getattr(self, 'connection', None)
        if connection is None:
            return
        now = time.time()
        # Only rows whose window already ended are obsolete; rows that have
        # not expired yet are exactly the state a later run would need.
        connection.execute('DELETE FROM per_second WHERE expire < ?', (now,))
        connection.execute('DELETE FROM rate WHERE expire < ?', (now,))
        connection.commit()
        connection.close()
        self.connection = None

    def __del__(self):
        self.close()

    def getCallsPerSecond(self, function, hash):
        """Return ``([last_reset], [calls])`` stored for ``function``/``hash``.

        The values are wrapped in one-element lists so callers can mutate
        them from closures. Unknown keys yield ``([0.0], [0])``.
        """
        self.cursor.execute(
            'SELECT last_reset, calls FROM per_second WHERE target = ? AND hash = ?',
            (function.__name__, hash))
        row = self.cursor.fetchone()
        if row is None:
            return ([0.0], [0])
        return ([row[0]], [row[1]])

    def setCallsPerSecond(self, function, hash, lastReset, calls, expire):
        """Insert or overwrite the ``per_second`` row for ``function``/``hash``."""
        # Re-using the existing id makes REPLACE update in place; when no row
        # exists the sub-select yields NULL and a new id is allocated.
        self.cursor.execute(
            'REPLACE INTO per_second (id, target, hash, last_reset, calls, expire) '
            'VALUES ((SELECT id FROM per_second WHERE target = :target AND hash = :hash), '
            ':target, :hash, :last_reset, :calls, :expire)',
            {'target': function.__name__, 'hash': hash, 'last_reset': lastReset,
             'calls': calls, 'expire': expire})
        self.connection.commit()

    def getCallsRate(self, function, hash):
        """Return ``[last_call]`` stored for ``function``/``hash`` (``[0.0]`` if unknown)."""
        self.cursor.execute(
            'SELECT last_call FROM rate WHERE target = ? AND hash = ?',
            (function.__name__, hash))
        row = self.cursor.fetchone()
        if row is None:
            return [0.0]
        return [row[0]]

    def setCallsRate(self, function, hash, lastCall, expire):
        """Insert or overwrite the ``rate`` row for ``function``/``hash``."""
        self.cursor.execute(
            'REPLACE INTO rate (id, target, hash, last_call, expire) '
            'VALUES ((SELECT id FROM rate WHERE target = :target AND hash = :hash), '
            ':target, :hash, :last_call, :expire)',
            {'target': function.__name__, 'hash': hash,
             'last_call': lastCall, 'expire': expire})
        self.connection.commit()
def limitCallsPerSecond(maxCalls, perSeconds, persistent, sleep=True):
    """Decorator factory allowing at most ``maxCalls`` calls per ``perSeconds`` window.

    The window state (time of the last reset and the call counter) is loaded
    from ``persistent.getCallsPerSecond`` and written back through
    ``persistent.setCallsPerSecond`` after every change.

    :param maxCalls: number of calls allowed in one window.
    :param perSeconds: window length in seconds.
    :param persistent: storage backend (e.g. ``SQLiteImplementation``).
    :param sleep: when True, a call that exceeds the budget waits for the
        window to end, opens a new window and then runs. When False, the call
        is skipped and the wrapper returns None.
    """
    def decorate(function):
        lock = threading.RLock()
        # Identifies this limit configuration in storage; b64encode needs
        # bytes on Python 3, and the storage layer expects text.
        configKey = '%d-%d-%s' % (maxCalls, perSeconds, sleep)
        configHash = base64.b64encode(configKey.encode()).decode()
        # One-element lists so the nested closures can mutate them.
        (lastReset, calls) = persistent.getCallsPerSecond(function, configHash)

        def store(expire):
            persistent.setCallsPerSecond(function, configHash, lastReset[0], calls[0], expire)

        def reset(now=None):
            # Read the clock at call time: a ``time.time()`` default would be
            # evaluated once, at decoration time.
            if now is None:
                now = time.time()
            lastReset[0] = now
            calls[0] = maxCalls
            store(now + perSeconds)

        @functools.wraps(function)
        def wrapper(*args, **kargs):
            # ``with`` releases the lock even if the storage backend raises.
            with lock:
                now = time.time()
                sinceLastReset = now - lastReset[0]
                if sinceLastReset > perSeconds:
                    reset(now)
                else:
                    calls[0] = calls[0] - 1
                    store(now + perSeconds)
                outOfCalls = calls[0] < 1
                if outOfCalls and sleep:
                    # Wait out the rest of the window (never a negative
                    # sleep), start a fresh one and let this call through.
                    time.sleep(max(0.0, perSeconds - sinceLastReset))
                    reset()
                    outOfCalls = False
            if not outOfCalls:
                return function(*args, **kargs)
            return None
        return wrapper
    return decorate
def limitCallsRate(maxPerSecond, perSecond, persistent, sleep=True):
    """Decorator factory spacing calls at least ``perSecond / maxPerSecond`` seconds apart.

    The time the previous call finished is loaded from
    ``persistent.getCallsRate`` and written back through
    ``persistent.setCallsRate`` after every call, including calls that raise.

    :param maxPerSecond: number of calls allowed per ``perSecond`` seconds.
    :param perSecond: length of the period in seconds.
    :param persistent: storage backend (e.g. ``SQLiteImplementation``).
    :param sleep: when True, a call that comes too early waits for the
        remaining interval. When False, it is skipped and returns None.
    """
    def decorate(function):
        lock = threading.RLock()
        minInterval = perSecond / float(maxPerSecond)
        configKey = '%d-%d-%s' % (maxPerSecond, perSecond, sleep)
        configHash = base64.b64encode(configKey.encode()).decode()
        # One-element list so the nested closures can mutate it.
        lastCall = persistent.getCallsRate(function, configHash)

        def store(expire):
            persistent.setCallsRate(function, configHash, lastCall[0], expire)

        @functools.wraps(function)
        def wrapper(*args, **kargs):
            # ``with`` guarantees the lock is released even when ``function``
            # raises; otherwise other threads would block forever.
            with lock:
                leftToWait = minInterval - (time.time() - lastCall[0])
                if leftToWait > 0:
                    if not sleep:
                        return None
                    time.sleep(leftToWait)
                try:
                    return function(*args, **kargs)
                finally:
                    lastCall[0] = time.time()
                    store(lastCall[0] + minInterval)
        return wrapper
    return decorate
if __name__ == "__main__":
    # Demo: PrintNumber may run at most 3 times per 4-second window, and
    # consecutive calls are spaced at least 1/2 second apart. Limiter state is
    # kept in the default SQLite database.
    persistent = SQLiteImplementation()

    @limitCallsPerSecond(3, 4, persistent)
    @limitCallsRate(2, 1, persistent)
    def PrintNumber(num):
        print("%s: %d" % (time.time(), num))
        time.sleep(0.01)
        return True

    for i in range(1, 10000):
        # PrintNumber returns True when it ran; a falsy result means a
        # limiter skipped the call, so back off briefly.
        if not PrintNumber(i):
            time.sleep(0.1)