欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Twitter的雪花算法(snowflake)自增ID

程序员文章站 2022-07-13 09:59:15
...

什么是雪花算法 snowflake

https://segmentfault.com/a/1190000011282426

前言

这个问题源自于,我想找一个分布式下的ID生成器。
  这个最简单的方案是,数据库自增ID。为啥不用咧?有这么几点原因,一是,会依赖于数据库的具体实现,比如,mysql有自增,oracle没有,得用序列,mongo似乎也没有他自己有个什么ID,sqlserver貌似有自增等等,有些不稳定因素,因为ID生成是业务的核心基础。当然,还有就是性能,自增ID是连续的,它就依赖于数据库自身的锁,所以数据库就有瓶颈。当然了,多台数据库加某种间隔也是可用的,但是,运维维护会很复杂,因为它不是内聚的解决方案。而且,很难提前获得下一个ID。
  后来,我用过一段时间在数据库表里进行记录来进行自增。这个的优势是,我可以提前获得下一个ID,而且,某个进程里可以一次获取一批,减少锁的依赖,虽然进程间的不重复依然是基于数据库事务隔离的,但是,依赖小了,瓶颈小了。这个方案其实挺好的,我依然也会继续用,主要是,它可以生成数字字母混合的编剧号,而且基本可控。但是,我数据库主键为了效率和空间成本,基本会选用long,基本顺序生成就可以了,所以,使用这种带持久化的方案,会显得很重。起项目的时候,也是,需要先建立对应的表,然后再把代码或者jar包引进去,然后再用,比较重。最好就是能够直接生成,没有那么多依赖。
  然后,我从我上司那里听到了twitter的这个算法。其实,我上司有个实现,我这个就是基于他的改的,但是,他的有两个值是配置的,我还是嫌麻烦,于是就动手把那两个值变成了从机器与进程获取,就有了这个版本。

思路

说实话,我也就听了这么个算法的名字,没正经看过原算法,但是,我上司说他代码是网上抄的,所以,这个算法名字我还是不敢丢,下面我们说说整体的思路。
  整个ID的构成大概分为这么几个部分,时间戳差值,机器编码,进程编码,***。java的long是64位的从左向右依次介绍是:时间戳差值,在我们这里占了42位;机器编码5位;进程编码5位;***12位。所有的拼接用位运算拼接起来,于是就基本做到了每个进程中不会重复了。

代码(这里是我修改原po的之后的代码)

原博:https://blog.csdn.net/linghuanxu/article/details/78896317

code1

package com.zgd.demo.util.idGenerator;

import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.Enumeration;

import static java.lang.System.currentTimeMillis;

/**
 * 雪花算法 snowflake 算法自增id生成器
 * @author zgd
 * @time 2018年8月8日10:45:08
 */
public class SnowFlakeIdGenerator {
    //机器启动时的时间戳,需要从机器启动时确定
    private final static long twepoch = currentTimeMillis();

    // 机器标识位数
    private final static long workerIdBits = 5L;
    // 数据中心标识位数
    private final static long datacenterIdBits = 5L;

    // 毫秒内自增位数
    private final static long sequenceBits = 12L;
    // 机器ID偏左移12位
    private final static long workerIdShift = sequenceBits;
    // 数据中心ID左移17位
    private final static long datacenterIdShift = sequenceBits + workerIdBits;
    // 时间毫秒左移22位
    private final static long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
    //sequence掩码,确保sequnce不会超出上限
    private final static long sequenceMask = -1L ^ (-1L << sequenceBits);
    //上次时间戳
    private static long lastTimestamp = -1L;
    //序列
    private long sequence = 0L;
    //服务器ID
    private long workerId = 1L;
	
	/**
	* long workerIdBits = 5L;
	* -1L 的二进制: 1111111111111111111111111111111111111111111111111111111111111111
	* -1L<<workerIdBits = -32 ,二进制: 1111111111111111111111111111111111111111111111111111111111100000
	*  workerMask= -1L ^ -32 = 31, 二进制: 11111
	*/
    private static long workerMask= -1L ^ (-1L << workerIdBits);
    //进程编码
    private long processId = 1L;
    private static long processMask=-1L ^ (-1L << datacenterIdBits);
    private static SnowFlakeIdGenerator idGenerator = null;

    static{
        idGenerator=new SnowFlakeIdGenerator();
    }
    public static synchronized long nextId(){
        return idGenerator.getNextId();
    }

    /**
     * 隐藏构造方法,单例
     */
    private SnowFlakeIdGenerator() {
        System.out.println("实例化了SnowFlakeIdGenerator");
        //获取机器编码
        this.workerId=this.getMachineNum();
        //获取进程编码
        RuntimeMXBean runtimeMXBean = ManagementFactory.getRuntimeMXBean();
        this.processId=Long.valueOf(runtimeMXBean.getName().split("@")[0]).longValue();

        //避免编码超出最大值
        /**
        * 如果workerId=489181L,二进制是1110111011011011101, workerMask上面已经得知是31, 二进制: 11111, 机器码的位数是5位,workerIdBits = 5L;
        * workerId & workerMask = 29,二进制 11101,所以二进制都是控制在5位
        * processId同理可以控制在二进制5位
        */
        this.workerId=workerId & workerMask;
        this.processId=processId & processMask;
    }

    public synchronized long getNextId() {
        //获取时间戳
        long timestamp = timeGen();
        //如果时间戳小于上次时间戳则报错
        if (timestamp < lastTimestamp) {
            try {
                throw new Exception("Clock moved backwards.  Refusing to generate id for " + (lastTimestamp - timestamp) + " milliseconds");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        //如果时间戳与上次时间戳相同
        if (lastTimestamp == timestamp) {
            // 当前毫秒内,则+1,与sequenceMask确保sequence不会超出上限
            sequence = (sequence + 1) & sequenceMask;
            if (sequence == 0) {
                // 当前毫秒内计数满了,则等待下一秒
                timestamp = tilNextMillis(lastTimestamp);
            }
        } else {
            sequence = 0;
        }
        lastTimestamp = timestamp;
        // ID偏移组合生成最终的ID,并返回ID
        //将时间戳的二进制向左移22位,进程id的二进制向左移12位,机器id的二进制左移5位,序号的二进制左移5位
        //再用||连接起来,所以形成了一个64位的long类型:
        //[xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx xx] [xx xxxx xxxx xx] [xx xxx][x xxxx]  从左到右分别是: 时间戳, 进程id,机器id,序号
        long nextId = ((timestamp - twepoch) << timestampLeftShift) | (processId << datacenterIdShift) | (workerId << workerIdShift) | sequence;
        return nextId;
    }

    /**
     * 再次获取时间戳直到获取的时间戳与现有的不同
     * @param lastTimestamp
     * @return 下一个时间戳
     */
    private long tilNextMillis(final long lastTimestamp) {
        long timestamp = this.timeGen();
        while (timestamp <= lastTimestamp) {
            timestamp = this.timeGen();
        }
        return timestamp;
    }

    private long timeGen() {
        return currentTimeMillis();
    }

    /**
     * 获取机器编码
     * @return
     */
    private long getMachineNum(){
        long machinePiece;
        StringBuilder sb = new StringBuilder();
        Enumeration<NetworkInterface> e = null;
        try {
            e = NetworkInterface.getNetworkInterfaces();
        } catch (SocketException e1) {
            e1.printStackTrace();
        }
        while (e.hasMoreElements()) {
            NetworkInterface ni = e.nextElement();
            sb.append(ni.toString());
        }
        machinePiece = sb.toString().hashCode();
        return machinePiece;
    }

}

试了一下,连续调用了三次nextId():

实例化了SnowFlakeIdGenerator
252760064
252760065
252760066
252760067
二进制分别是:
1111000100001101000000000000
1111000100001101000000000001
1111000100001101000000000010
1111000100001101000000000011

还有另一种方法

code2


import java.lang.management.ManagementFactory;
import java.net.InetAddress;
import java.net.NetworkInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Sequence {
    private static final Logger log = LoggerFactory.getLogger(Sequence.class);
    private final long twepoch = 1288834974657L;//这里是机器启动的时间戳
    private final long workerIdBits = 5L;
    private final long datacenterIdBits = 5L;
    private final long maxWorkerId = 31L;
    private final long maxDataCenterId = 31L;
    private final long sequenceBits = 12L;
    private final long workerIdShift = 12L;
    private final long datacenterIdShift = 17L;
    private final long timestampLeftShift = 22L;
    private final long sequenceMask = 4095L;
    private long workerId;
    private long dataCenterId;
    private long sequence = 0L;
    private long lastTimestamp = -1L;
    private long offSet = 5L;

    public Sequence() {
        this.dataCenterId = getDataCenterId(31L);
        this.workerId = getMaxWorkerId(this.dataCenterId, 31L);
    }

    public Sequence(long workerId, long dataCenterId) {
        if (workerId <= 31L && workerId >= 0L) {
            if (dataCenterId <= 31L && dataCenterId >= 0L) {
                this.workerId = workerId;
                this.dataCenterId = dataCenterId;
            } else {
                throw new IllegalArgumentException(String.format("datacenter Id can't be greater than %d or less than 0", 31L));
            }
        } else {
            throw new IllegalArgumentException(String.format("worker Id can't be greater than %d or less than 0", 31L));
        }
    }

    protected static long getMaxWorkerId(long dataCenterId, long maxWorkerId) {
        StringBuilder mpid = new StringBuilder();
        mpid.append(dataCenterId);
        String name = ManagementFactory.getRuntimeMXBean().getName();
        if (name != null && "".equals(name)) {
            mpid.append(name.split("@")[0]);
        }

        return (long)(mpid.toString().hashCode() & '\uffff') % (maxWorkerId + 1L);
    }

    protected static long getDataCenterId(long maxDataCenterId) {
        long id = 0L;

        try {
            InetAddress ip = InetAddress.getLocalHost();
            NetworkInterface network = NetworkInterface.getByInetAddress(ip);
            if (network == null) {
                id = 1L;
            } else {
                byte[] mac = network.getHardwareAddress();
                if (null != mac) {
                    id = (255L & (long)mac[mac.length - 1] | 65280L & (long)mac[mac.length - 2] << 8) >> 6;
                    id %= maxDataCenterId + 1L;
                }
            }
        } catch (Exception var7) {
            log.error(" getDataCenterId: ", var7);
        }

        return id;
    }

    public synchronized long nextId() {
        long timestamp = this.timeGen();
        if (timestamp < this.lastTimestamp) {
            long offset = this.lastTimestamp - timestamp;
            if (offset > this.offSet) {
                throw new RuntimeException(String.format("Clock moved backwards.  Refusing to generate id for %d milliseconds", offset));
            }

            try {
                this.wait(offset << 1);
                timestamp = this.timeGen();
                if (timestamp < this.lastTimestamp) {
                    throw new RuntimeException(String.format("Clock moved backwards.  Refusing to generate id for %d milliseconds", offset));
                }
            } catch (Exception var6) {
                throw new RuntimeException(var6);
            }
        }

        if (this.lastTimestamp == timestamp) {
            this.sequence = this.sequence + 1L & 4095L;
            if (this.sequence == 0L) {
                timestamp = this.tilNextMillis(this.lastTimestamp);
            }
        } else {
            this.sequence = 0L;
        }

        this.lastTimestamp = timestamp;
        return timestamp - twepoch  << 22 | this.dataCenterId << 17 | this.workerId << 12 | this.sequence;
    }

    protected long tilNextMillis(long lastTimestamp) {
        long timestamp;
        for(timestamp = this.timeGen(); timestamp <= lastTimestamp; timestamp = this.timeGen()) {
            ;
        }

        return timestamp;
    }

    protected long timeGen() {
        return SystemClock.now();
    }
}

自己写的测试类

package com.zgd.demo.test;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.zgd.demo.util.idGenerator.SnowFlakeIdGenerator;

import java.util.concurrent.*;

public class TestIdRepeat {
    public static void main(String[] args) throws Exception {
        // 线程数量
        final int threadCount = 100;
        // 每个线程生成的 ID 数量
        final int idCountPerThread = 1000;
        // 用于等待所有线程启动完成
        CountDownLatch threadLatch = new CountDownLatch(threadCount);

        final int coreThread = 5;
        final int maxThread = 50;
        final long keepAliveTime = 0L;
        final int queueCapacity = 1024;


        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder().setNameFormat("demo-pool-%d").build();

        //Common Thread Pool
        ExecutorService pool = new ThreadPoolExecutor(coreThread, maxThread,keepAliveTime, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<Runnable>(queueCapacity), namedThreadFactory, new ThreadPoolExecutor.AbortPolicy());



        ConcurrentSkipListSet<Long> ids = new ConcurrentSkipListSet<>();
        for (int i = 0; i < threadCount; ++i) {
            final int n = i;
            pool.execute(() -> {
                // 等待所有线程都运行到这里,然后都继续运行,差不多同时生成 id
                final String threadNum = Thread.currentThread().getName() + "-" + n + "号线程";
                try {
                    threadLatch.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(threadNum+"继续执行");
                for (int j = 0; j < idCountPerThread; ++j) {
                    long id = SnowFlakeIdGenerator.nextId();
                    ids.add(id);
                }
            });
            threadLatch.countDown();
        }
        pool.shutdown();
        // 等待 id 生成完成,生成不同数量的 id 时需要调整
        Thread.sleep(2000);
        System.out.println(ids.size());
        //System.out.println(ids);
    }
}


结果

Twitter的雪花算法(snowflake)自增ID
因为ConcurrentSkipListSet是不重复且有序的,所以打印出来是100000个,说明没有重复的id

>github源码<


代码解读

1.整体设计

为了最大程度的减少配置,方便实用,这个模块,我设计成了单例模式。之所以没有直接使用static方法,还是希望可以控制整个模块的生命周期,但是,模块的初始化,我使用了static块,因为它没有任何依赖。
  有个static的nextId方法,可以直接获得下一个ID,这个方法是线程安全的。同时这个模块的使用就是这么简单粗暴,也不用配置bean。

2.ID生成逻辑

我们先看最后一步:long nextId = ((timestamp - twepoch) << timestampLeftShift) | (processId << datacenterIdShift) | (workerId << workerIdShift) | sequence;
  这句话什么意思呢?
  timestamp - twepoch:时间戳减去一个时间戳,获得一个差值。
  ((timestamp - twepoch) << timestampLeftShift):timestampLeftShift是22,这个操作是将这个差值向左移22位,左移空出来的会自动补0,我们就有了22位的空间了。
  后面可以看到三个|符号,与操作会把1都加进来,而我们后面的数也都在各自的位上才有1,那么|操作就把这些数合进来了。
  (processId << datacenterIdShift):进程编码左移datacenterIdShift,这个是17位,而processId最多是5位,于是刚好填满空位
  (workerId << workerIdShift):与进程编码类似,机器编码也是5位,左移12位
  sequence最大12位。

如何确保不超出位数限制

前面的逻辑中,我们说了很多不超出位数限制啥的内容,那么,具体是怎么做到的呢?我们拿workerId举个例子:
  this.workerId=workerId & workerMask;
  这是我们确保workerId不超过5位的语句,什么意思呢?不经常操作位运算真看不懂。我们先看看workerMask是啥。
  private static long workerMask= -1L ^ (-1L << workerIdBits);
  。。。什么意思呀?它先执行的是-1L << workerIdBits,workerIdBits是5。这又是什么意思呢?注意,这是位运算,long用的是补码,-1L,就是64个1,这里使用-1是为了格式化所有位数,<<是左移运算,-1L左移五位,低位补零,也就是左移空出来的会自动补0,于是就低位五位是0,其余是1。然后^这个符号,是异或,也是位运算,位上相同则为0,不通则为1,和-1做异或,则把所有的0和1颠倒了一下。这时候,我们再看,workerId & workerMask,与操作,两个位上都为1的才能唯一,否则为零,workerMask高位都是0,所以,不管workerId高位是什么,都是0,;而workerMask低位都是1,所以,不管workerId低位是什么,都会被保留,于是,我们就控制了workerId的范围。

最后的异常

这里,时间戳,保证了不通毫秒不同,然后机器编码进程编码保证了不同进程不通,再然后,序列,在统一毫秒内,如果获取第二个ID,则***+1,到下一毫秒后重置。至此,唯一性ok。但是,还有问题,***用完了怎么办?代码里的解决方案是,等到下一毫秒。

补充
  其实,这个方案中,机器码和进程编码是可能相同的,只是概率比较小,我们就凑合着用吧。如果有更好地获取这两位的方式,欢迎沟通。