diff --git a/src/main/java/redis/clients/jedis/ShardedJedis.java b/src/main/java/redis/clients/jedis/ShardedJedis.java new file mode 100644 index 0000000..d3b6266 --- /dev/null +++ b/src/main/java/redis/clients/jedis/ShardedJedis.java @@ -0,0 +1,356 @@ +package redis.clients.jedis; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import redis.clients.util.ShardInfo; +import redis.clients.util.Sharded; + +public class ShardedJedis extends Sharded { + public ShardedJedis(List shards) { + super(shards); + } + + public String set(String key, String value) { + Jedis j = getShard(key); + return j.set(key, value); + } + + public String get(String key) { + Jedis j = getShard(key); + return j.get(key); + } + + public int exists(String key) { + Jedis j = getShard(key); + return j.exists(key); + } + + public String type(String key) { + Jedis j = getShard(key); + return j.type(key); + } + + public int expire(String key, int seconds) { + Jedis j = getShard(key); + return j.expire(key, seconds); + } + + public int expireAt(String key, long unixTime) { + Jedis j = getShard(key); + return j.expireAt(key, unixTime); + } + + public int ttl(String key) { + Jedis j = getShard(key); + return j.ttl(key); + } + + public String getSet(String key, String value) { + Jedis j = getShard(key); + return j.getSet(key, value); + } + + public int setnx(String key, String value) { + Jedis j = getShard(key); + return j.setnx(key, value); + } + + public String setex(String key, int seconds, String value) { + Jedis j = getShard(key); + return j.setex(key, seconds, value); + } + + public int decrBy(String key, int integer) { + Jedis j = getShard(key); + return j.decrBy(key, integer); + } + + public int decr(String key) { + Jedis j = getShard(key); + return j.decr(key); + } + + public int incrBy(String key, int integer) { + Jedis j = getShard(key); + return j.incrBy(key, integer); + } + + public int incr(String key) { + Jedis j = getShard(key); + return j.incr(key); + } + + public int append(String key, String value) { + Jedis j = getShard(key); + return j.append(key, value); + } + + public String substr(String key, int start, int end) { + Jedis j = getShard(key); + return j.substr(key, start, end); + } + + public int hset(String key, String field, String value) { + Jedis j = getShard(key); + return j.hset(key, field, value); + } + + public String hget(String key, String field) { + Jedis j = getShard(key); + return j.hget(key, field); + } + + public int hsetnx(String key, String field, String value) { + Jedis j = getShard(key); + return j.hsetnx(key, field, value); + } + + public String hmset(String key, Map hash) { + Jedis j = getShard(key); + return j.hmset(key, hash); + } + + public List hmget(String key, String... fields) { + Jedis j = getShard(key); + return j.hmget(key, fields); + } + + public int hincrBy(String key, String field, int value) { + Jedis j = getShard(key); + return j.hincrBy(key, field, value); + } + + public int hexists(String key, String field) { + Jedis j = getShard(key); + return j.hexists(key, field); + } + + public int hdel(String key, String field) { + Jedis j = getShard(key); + return j.hdel(key, field); + } + + public int hlen(String key) { + Jedis j = getShard(key); + return j.hlen(key); + } + + public List hkeys(String key) { + Jedis j = getShard(key); + return j.hkeys(key); + } + + public List hvals(String key) { + Jedis j = getShard(key); + return j.hvals(key); + } + + public Map hgetAll(String key) { + Jedis j = getShard(key); + return j.hgetAll(key); + } + + public int rpush(String key, String string) { + Jedis j = getShard(key); + return j.rpush(key, string); + } + + public int lpush(String key, String string) { + Jedis j = getShard(key); + return j.lpush(key, string); + } + + public int llen(String key) { + Jedis j = getShard(key); + return j.llen(key); + } + + public List lrange(String key, int start, int end) { + Jedis j = getShard(key); + return j.lrange(key, start, end); + } + + public String ltrim(String key, int start, int end) { + Jedis j = getShard(key); + return j.ltrim(key, start, end); + } + + public String lindex(String key, int index) { + Jedis j = getShard(key); + return j.lindex(key, index); + } + + public String lset(String key, int index, String value) { + Jedis j = getShard(key); + return j.lset(key, index, value); + } + + public int lrem(String key, int count, String value) { + Jedis j = getShard(key); + return j.lrem(key, count, value); + } + + public String lpop(String key) { + Jedis j = getShard(key); + return j.lpop(key); + } + + public String rpop(String key) { + Jedis j = getShard(key); + return j.rpop(key); + } + + public int sadd(String key, String member) { + Jedis j = getShard(key); + return j.sadd(key, member); + } + + public Set smembers(String key) { + Jedis j = getShard(key); + return j.smembers(key); + } + + public int srem(String key, String member) { + Jedis j = getShard(key); + return j.srem(key, member); + } + + public String spop(String key) { + Jedis j = getShard(key); + return j.spop(key); + } + + public int scard(String key) { + Jedis j = getShard(key); + return j.scard(key); + } + + public int sismember(String key, String member) { + Jedis j = getShard(key); + return j.sismember(key, member); + } + + public String srandmember(String key) { + Jedis j = getShard(key); + return j.srandmember(key); + } + + public int zadd(String key, double score, String member) { + Jedis j = getShard(key); + return j.zadd(key, score, member); + } + + public Set zrange(String key, int start, int end) { + Jedis j = getShard(key); + return j.zrange(key, start, end); + } + + public int zrem(String key, String member) { + Jedis j = getShard(key); + return j.zrem(key, member); + } + + public double zincrby(String key, double score, String member) { + Jedis j = getShard(key); + return j.zincrby(key, score, member); + } + + public int zrank(String key, String member) { + Jedis j = getShard(key); + return j.zrank(key, member); + } + + public int zrevrank(String key, String member) { + Jedis j = getShard(key); + return j.zrevrank(key, member); + } + + public Set zrevrange(String key, int start, int end) { + Jedis j = getShard(key); + return j.zrevrange(key, start, end); + } + + public Set zrangeWithScores(String key, int start, int end) { + Jedis j = getShard(key); + return j.zrangeWithScores(key, start, end); + } + + public Set zrevrangeWithScores(String key, int start, int end) { + Jedis j = getShard(key); + return j.zrevrangeWithScores(key, start, end); + } + + public int zcard(String key) { + Jedis j = getShard(key); + return j.zcard(key); + } + + public double zscore(String key, String member) { + Jedis j = getShard(key); + return j.zscore(key, member); + } + + public List sort(String key) { + Jedis j = getShard(key); + return j.sort(key); + } + + public List sort(String key, SortingParams sortingParameters) { + Jedis j = getShard(key); + return j.sort(key, sortingParameters); + } + + public int zcount(String key, double min, double max) { + Jedis j = getShard(key); + return j.zcount(key, min, max); + } + + public Set zrangeByScore(String key, double min, double max) { + Jedis j = getShard(key); + return j.zrangeByScore(key, min, max); + } + + public Set zrangeByScore(String key, double min, double max, + int offset, int count) { + Jedis j = getShard(key); + return j.zrangeByScore(key, min, max, offset, count); + } + + public Set zrangeByScoreWithScores(String key, double min, double max) { + Jedis j = getShard(key); + return j.zrangeByScoreWithScores(key, min, max); + } + + public Set zrangeByScoreWithScores(String key, double min, + double max, int offset, int count) { + Jedis j = getShard(key); + return j.zrangeByScoreWithScores(key, min, max, offset, count); + } + + public int zremrangeByRank(String key, int start, int end) { + Jedis j = getShard(key); + return j.zremrangeByRank(key, start, end); + } + + public int zremrangeByScore(String key, double start, double end) { + Jedis j = getShard(key); + return j.zremrangeByScore(key, start, end); + } + + public void disconnect() throws IOException { + for (Jedis jedis : getAllShards()) { + jedis.disconnect(); + } + } + + protected Jedis create(ShardInfo shard) { + Jedis c = new Jedis(shard.getHost(), shard.getPort()); + if (shard.getPassword() != null) { + c.auth(shard.getPassword()); + } + return c; + } +} \ No newline at end of file diff --git a/src/main/java/redis/clients/util/ShardInfo.java b/src/main/java/redis/clients/util/ShardInfo.java new file mode 100644 index 0000000..3a14ea1 --- /dev/null +++ b/src/main/java/redis/clients/util/ShardInfo.java @@ -0,0 +1,98 @@ +package redis.clients.util; + +import redis.clients.jedis.Protocol; + +public class ShardInfo { + @Override + public String toString() { + return "ShardInfo [host=" + host + ", port=" + port + ", weight=" + + weight + "]"; + } + + private String host; + private int port; + private int timeout; + private int weight; + private String password = null; + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public int getTimeout() { + return timeout; + } + + public ShardInfo(String host) { + this(host, Protocol.DEFAULT_PORT); + } + + public ShardInfo(String host, int port) { + this(host, port, 2000); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((host == null) ? 0 : host.hashCode()); + result = prime * result + port; + result = prime * result + timeout; + result = prime * result + weight; + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + ShardInfo other = (ShardInfo) obj; + if (host == null) { + if (other.host != null) + return false; + } else if (!host.equals(other.host)) + return false; + if (port != other.port) + return false; + if (timeout != other.timeout) + return false; + if (weight != other.weight) + return false; + return true; + } + + public ShardInfo(String host, int port, int timeout) { + this(host, port, timeout, Sharded.DEFAULT_WEIGHT); + } + + public ShardInfo(String host, int port, int timeout, int weight) { + this.host = host; + this.port = port; + this.timeout = timeout; + this.weight = weight; + } + + public String getPassword() { + return password; + } + + public void setPassword(String auth) { + this.password = auth; + } + + public void setTimeout(int timeout) { + this.timeout = timeout; + } + + public int getWeight() { + return this.weight; + } +} diff --git a/src/main/java/redis/clients/util/Sharded.java b/src/main/java/redis/clients/util/Sharded.java new file mode 100644 index 0000000..e626f48 --- /dev/null +++ b/src/main/java/redis/clients/util/Sharded.java @@ -0,0 +1,102 @@ +package redis.clients.util; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +public abstract class Sharded { + public static final int DEFAULT_WEIGHT = 1; + private static MessageDigest md5 = null; // avoid recurring construction + private TreeMap nodes; + private int totalWeight; + private Map resources; + + public Sharded(List shards) { + initialize(shards); + } + + private void initialize(List shards) { + nodes = new TreeMap(); + resources = new HashMap(); + + totalWeight = 0; + + for (ShardInfo shard : shards) { + totalWeight += shard.getWeight(); + } + + MessageDigest md5; + try { + md5 = MessageDigest.getInstance("MD5"); + } catch (NoSuchAlgorithmException e) { + throw new IllegalStateException("++++ no md5 algorythm found"); + } + + for (ShardInfo shard : shards) { + double factor = Math + .floor(((double) (40 * shards.size() * DEFAULT_WEIGHT)) + / (double) totalWeight); + + for (long j = 0; j < factor; j++) { + byte[] d = md5.digest((shard.toString() + "-" + j).getBytes()); + for (int h = 0; h < 4; h++) { + Long k = ((long) (d[3 + h * 4] & 0xFF) << 24) + | ((long) (d[2 + h * 4] & 0xFF) << 16) + | ((long) (d[1 + h * 4] & 0xFF) << 8) + | ((long) (d[0 + h * 4] & 0xFF)); + nodes.put(k, shard); + } + } + resources.put(shard, create(shard)); + } + } + + public ShardInfo getShardInfo(String key) { + long hv = calculateHash(key); + + return nodes.get(findPointFor(hv)); + } + + private Long calculateHash(String key) { + if (md5 == null) { + try { + md5 = MessageDigest.getInstance("MD5"); + } catch (NoSuchAlgorithmException e) { + throw new IllegalStateException("++++ no md5 algorythm found"); + } + } + + md5.reset(); + md5.update(key.getBytes()); + byte[] bKey = md5.digest(); + long res = ((long) (bKey[3] & 0xFF) << 24) + | ((long) (bKey[2] & 0xFF) << 16) + | ((long) (bKey[1] & 0xFF) << 8) | (long) (bKey[0] & 0xFF); + return res; + } + + private Long findPointFor(Long hashK) { + Long k = nodes.ceilingKey(hashK); + + if (k == null) { + k = nodes.firstKey(); + } + + return k; + } + + public T getShard(String key) { + ShardInfo shard = getShardInfo(key); + return resources.get(shard); + } + + protected abstract T create(ShardInfo shard); + + public Collection getAllShards() { + return resources.values(); + } +} \ No newline at end of file diff --git a/src/test/java/redis/clients/jedis/tests/JedisTest.java b/src/test/java/redis/clients/jedis/tests/JedisTest.java index 2af5e42..1888cd2 100644 --- a/src/test/java/redis/clients/jedis/tests/JedisTest.java +++ b/src/test/java/redis/clients/jedis/tests/JedisTest.java @@ -13,6 +13,7 @@ public class JedisTest extends JedisCommandTestBase { @Test public void useWithoutConnecting() { Jedis jedis = new Jedis("localhost"); + jedis.auth("foobared"); jedis.dbSize(); } diff --git a/src/test/java/redis/clients/jedis/tests/ShardedJedisTest.java b/src/test/java/redis/clients/jedis/tests/ShardedJedisTest.java new file mode 100644 index 0000000..c8f861c --- /dev/null +++ b/src/test/java/redis/clients/jedis/tests/ShardedJedisTest.java @@ -0,0 +1,53 @@ +package redis.clients.jedis.tests; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Assert; +import org.junit.Test; + +import redis.clients.jedis.Jedis; +import redis.clients.jedis.Protocol; +import redis.clients.jedis.ShardedJedis; +import redis.clients.util.ShardInfo; + +public class ShardedJedisTest extends Assert { + @Test + public void checkSharding() throws IOException { + List shards = new ArrayList(); + shards.add(new ShardInfo("localhost", Protocol.DEFAULT_PORT)); + shards.add(new ShardInfo("localhost", Protocol.DEFAULT_PORT + 1)); + ShardedJedis jedis = new ShardedJedis(shards); + ShardInfo s1 = jedis.getShardInfo("a"); + ShardInfo s2 = jedis.getShardInfo("b"); + assertNotSame(s1, s2); + } + + @Test + public void trySharding() throws IOException { + List shards = new ArrayList(); + ShardInfo si = new ShardInfo("localhost", Protocol.DEFAULT_PORT); + si.setPassword("foobared"); + shards.add(si); + si = new ShardInfo("localhost", Protocol.DEFAULT_PORT + 1); + si.setPassword("foobared"); + shards.add(si); + ShardedJedis jedis = new ShardedJedis(shards); + jedis.set("a", "bar"); + ShardInfo s1 = jedis.getShardInfo("a"); + jedis.set("b", "bar1"); + ShardInfo s2 = jedis.getShardInfo("b"); + jedis.disconnect(); + + Jedis j = new Jedis(s1.getHost(), s1.getPort()); + j.auth("foobared"); + assertEquals("bar", j.get("a")); + j.disconnect(); + + j = new Jedis(s2.getHost(), s2.getPort()); + j.auth("foobared"); + assertEquals("bar1", j.get("b")); + j.disconnect(); + } +} \ No newline at end of file