
Example Implementation of the Timing Wheel Algorithm in Java

程序员文章站 2024-02-22 17:11:34

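Consider this scenario: there are 5,000 tasks, and each of them must trigger some operation every 5 minutes. How would you implement that? Most people first think of a timer, but with 5,000 tasks you would need 5,000 timers, and since each timer (a java.util.Timer, for example) runs on its own thread, that means 5,000 threads. You get the idea: this approach clearly does not work.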

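Scenarios like this gave rise to the timing wheel algorithm. What exactly is a timing wheel? In my usual style: go Google it yourself. All right, I will be merciful and point you to an introduction to timing wheels. Read the text and the diagrams and skip the code, because the code in that article does not run.

Once you have read the introduction, let's get started.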
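For readers who want the core idea in code before reading the full listing, here is a minimal sketch. It is a single-threaded toy and not the implementation used in this article: the class name `MiniWheel`, its `schedule` and `tick` methods, and the use of tick counts instead of real time are all assumptions made for illustration. It shows the two things a timing wheel stores per task: the slot the task lands in, and how many full turns of the wheel it still has to wait.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Toy single-threaded timing wheel (hypothetical); the caller drives it by calling tick(). */
class MiniWheel {
    /** One scheduled task plus the number of full wheel turns it still has to wait. */
    private static final class Entry {
        long remainingRounds;
        final Runnable task;

        Entry(long remainingRounds, Runnable task) {
            this.remainingRounds = remainingRounds;
            this.task = task;
        }
    }

    private final List<List<Entry>> slots = new ArrayList<>();
    private long currentTick; // number of ticks processed so far

    MiniWheel(int slotCount) {
        for (int i = 0; i < slotCount; i++) {
            slots.add(new ArrayList<>());
        }
    }

    /** Schedules task to run delayTicks ticks from now (delayTicks must be at least 1). */
    void schedule(long delayTicks, Runnable task) {
        if (delayTicks < 1) {
            throw new IllegalArgumentException("delayTicks must be >= 1");
        }
        long target = currentTick + delayTicks;
        int index = (int) (target % slots.size());
        long rounds = (delayTicks - 1) / slots.size();
        slots.get(index).add(new Entry(rounds, task));
    }

    /** Advances the wheel by one slot and runs every task in that slot that is due. */
    void tick() {
        currentTick++;
        List<Entry> slot = slots.get((int) (currentTick % slots.size()));
        List<Runnable> due = new ArrayList<>();
        for (Iterator<Entry> it = slot.iterator(); it.hasNext();) {
            Entry entry = it.next();
            if (entry.remainingRounds == 0) {
                it.remove();
                due.add(entry.task);
            } else {
                entry.remainingRounds--;
            }
        }
        // Run after the scan so that a task may safely schedule into the same slot.
        for (Runnable task : due) {
            task.run();
        }
    }

    public static void main(String[] args) {
        MiniWheel wheel = new MiniWheel(8);
        wheel.schedule(3, () -> System.out.println("short delay fired"));
        wheel.schedule(11, () -> System.out.println("long delay fired"));
        for (int i = 0; i < 12; i++) {
            wheel.tick();
        }
    }
}
```

However many tasks are scheduled, one thread advancing the wheel is enough, which is the point of the 5,000-task scenario above. The real implementation below adds what this toy leaves out: a worker thread that sleeps between ticks, thread-safe queues for new and cancelled timeouts, and nanosecond deadlines.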

Development environment: IDEA + JDK 1.8 + Maven

Create a new Maven project


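Create the directory structure. As the listings below show, all classes live in the package com.tanghuachun.timer: Timeout.java, Timer.java, TimerTask.java, TimerWheel.java and Main.java.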

Don't forget to add the Netty library to pom.xml

<dependencies>
    <dependency>
      <groupId>io.netty</groupId>
      <artifactId>netty-all</artifactId>
      <version>4.1.5.Final</version>
    </dependency>
  </dependencies>

The code is as follows

Timeout.java

package com.tanghuachun.timer;
public interface Timeout {
  Timer timer();
  TimerTask task();
  boolean isExpired();
  boolean isCancelled();
  boolean cancel();
}

Timer.java

package com.tanghuachun.timer;
import java.util.Set;
import java.util.concurrent.TimeUnit;

public interface Timer {
  Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv);
  Set<Timeout> stop();
}

TimerTask.java

package com.tanghuachun.timer;
public interface TimerTask {
  void run(Timeout timeout, String argv) throws Exception;
}

TimerWheel.java

/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.tanghuachun.timer;
import io.netty.util.*;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.Collections;
import java.util.HashSet;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
public class TimerWheel implements Timer {

  static final InternalLogger logger =
      InternalLoggerFactory.getInstance(TimerWheel.class);

  private static final ResourceLeakDetector<TimerWheel> leakDetector = ResourceLeakDetectorFactory.instance()
      .newResourceLeakDetector(TimerWheel.class, 1, Runtime.getRuntime().availableProcessors() * 4L);

  private static final AtomicIntegerFieldUpdater<TimerWheel> WORKER_STATE_UPDATER;
  static {
    AtomicIntegerFieldUpdater<TimerWheel> workerStateUpdater =
        PlatformDependent.newAtomicIntegerFieldUpdater(TimerWheel.class, "workerState");
    if (workerStateUpdater == null) {
      workerStateUpdater = AtomicIntegerFieldUpdater.newUpdater(TimerWheel.class, "workerState");
    }
    WORKER_STATE_UPDATER = workerStateUpdater;
  }

  private final ResourceLeak leak;
  private final Worker worker = new Worker();
  private final Thread workerThread;

  public static final int WORKER_STATE_INIT = 0;
  public static final int WORKER_STATE_STARTED = 1;
  public static final int WORKER_STATE_SHUTDOWN = 2;
  @SuppressWarnings({ "unused", "FieldMayBeFinal", "RedundantFieldInitialization" })
  private volatile int workerState = WORKER_STATE_INIT; // 0 - init, 1 - started, 2 - shut down

  private final long tickDuration;
  private final HashedWheelBucket[] wheel;
  private final int mask;
  private final CountDownLatch startTimeInitialized = new CountDownLatch(1);
  private final Queue<HashedWheelTimeout> timeouts = PlatformDependent.newMpscQueue();
  private final Queue<HashedWheelTimeout> cancelledTimeouts = PlatformDependent.newMpscQueue();

  private volatile long startTime;

  /**
   * Creates a new timer with the default thread factory
   * ({@link Executors#defaultThreadFactory()}), default tick duration, and
   * default number of ticks per wheel.
   */
  public TimerWheel() {
    this(Executors.defaultThreadFactory());
  }

  /**
   * Creates a new timer with the default thread factory
   * ({@link Executors#defaultThreadFactory()}) and default number of ticks
   * per wheel.
   *
   * @param tickDuration  the duration between tick
   * @param unit      the time unit of the {@code tickDuration}
   * @throws NullPointerException   if {@code unit} is {@code null}
   * @throws IllegalArgumentException if {@code tickDuration} is <= 0
   */
  public TimerWheel(long tickDuration, TimeUnit unit) {
    this(Executors.defaultThreadFactory(), tickDuration, unit);
  }

  /**
   * Creates a new timer with the default thread factory
   * ({@link Executors#defaultThreadFactory()}).
   *
   * @param tickDuration  the duration between tick
   * @param unit      the time unit of the {@code tickDuration}
   * @param ticksPerWheel the size of the wheel
   * @throws NullPointerException   if {@code unit} is {@code null}
   * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0
   */
  public TimerWheel(long tickDuration, TimeUnit unit, int ticksPerWheel) {
    this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel);
  }

  /**
   * Creates a new timer with the default tick duration and default number of
   * ticks per wheel.
   *
   * @param threadFactory a {@link ThreadFactory} that creates a
   *            background {@link Thread} which is dedicated to
   *            {@link TimerTask} execution.
   * @throws NullPointerException if {@code threadFactory} is {@code null}
   */
  public TimerWheel(ThreadFactory threadFactory) {
    this(threadFactory, 100, TimeUnit.MILLISECONDS);
  }

  /**
   * Creates a new timer with the default number of ticks per wheel.
   *
   * @param threadFactory a {@link ThreadFactory} that creates a
   *            background {@link Thread} which is dedicated to
   *            {@link TimerTask} execution.
   * @param tickDuration  the duration between tick
   * @param unit      the time unit of the {@code tickDuration}
   * @throws NullPointerException   if either of {@code threadFactory} and {@code unit} is {@code null}
   * @throws IllegalArgumentException if {@code tickDuration} is <= 0
   */
  public TimerWheel(
      ThreadFactory threadFactory, long tickDuration, TimeUnit unit) {
    this(threadFactory, tickDuration, unit, 512);
  }

  /**
   * Creates a new timer.
   *
   * @param threadFactory a {@link ThreadFactory} that creates a
   *            background {@link Thread} which is dedicated to
   *            {@link TimerTask} execution.
   * @param tickDuration  the duration between tick
   * @param unit      the time unit of the {@code tickDuration}
   * @param ticksPerWheel the size of the wheel
   * @throws NullPointerException   if either of {@code threadFactory} and {@code unit} is {@code null}
   * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0
   */
  public TimerWheel(
      ThreadFactory threadFactory,
      long tickDuration, TimeUnit unit, int ticksPerWheel) {
    this(threadFactory, tickDuration, unit, ticksPerWheel, true);
  }

  /**
   * Creates a new timer.
   *
   * @param threadFactory a {@link ThreadFactory} that creates a
   *            background {@link Thread} which is dedicated to
   *            {@link TimerTask} execution.
   * @param tickDuration  the duration between tick
   * @param unit      the time unit of the {@code tickDuration}
   * @param ticksPerWheel the size of the wheel
   * @param leakDetection {@code true} if leak detection should be enabled always, if false it will only be enabled
   *            if the worker thread is not a daemon thread.
   * @throws NullPointerException   if either of {@code threadFactory} and {@code unit} is {@code null}
   * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0
   */
  public TimerWheel(
      ThreadFactory threadFactory,
      long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) {

    if (threadFactory == null) {
      throw new NullPointerException("threadFactory");
    }
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    if (tickDuration <= 0) {
      throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration);
    }
    if (ticksPerWheel <= 0) {
      throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }

    // Normalize ticksPerWheel to power of two and initialize the wheel.
    wheel = createWheel(ticksPerWheel);
    mask = wheel.length - 1;

    // Convert tickDuration to nanos.
    this.tickDuration = unit.toNanos(tickDuration);

    // Prevent overflow.
    if (this.tickDuration >= Long.MAX_VALUE / wheel.length) {
      throw new IllegalArgumentException(String.format(
          "tickDuration: %d (expected: 0 < tickDuration in nanos < %d)",
          tickDuration, Long.MAX_VALUE / wheel.length));
    }
    workerThread = threadFactory.newThread(worker);

    leak = leakDetection || !workerThread.isDaemon() ? leakDetector.open(this) : null;
  }

  private static HashedWheelBucket[] createWheel(int ticksPerWheel) {
    if (ticksPerWheel <= 0) {
      throw new IllegalArgumentException(
          "ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }
    if (ticksPerWheel > 1073741824) {
      throw new IllegalArgumentException(
          "ticksPerWheel may not be greater than 2^30: " + ticksPerWheel);
    }

    ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel);
    HashedWheelBucket[] wheel = new HashedWheelBucket[ticksPerWheel];
    for (int i = 0; i < wheel.length; i ++) {
      wheel[i] = new HashedWheelBucket();
    }
    return wheel;
  }

  private static int normalizeTicksPerWheel(int ticksPerWheel) {
    int normalizedTicksPerWheel = 1;
    while (normalizedTicksPerWheel < ticksPerWheel) {
      normalizedTicksPerWheel <<= 1;
    }
    return normalizedTicksPerWheel;
  }

  /**
   * Starts the background thread explicitly. The background thread will
   * start automatically on demand even if you did not call this method.
   *
   * @throws IllegalStateException if this timer has been
   *                {@linkplain #stop() stopped} already
   */
  public void start() {
    switch (WORKER_STATE_UPDATER.get(this)) {
      case WORKER_STATE_INIT:
        if (WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_INIT, WORKER_STATE_STARTED)) {
          workerThread.start();
        }
        break;
      case WORKER_STATE_STARTED:
        break;
      case WORKER_STATE_SHUTDOWN:
        throw new IllegalStateException("cannot be started once stopped");
      default:
        throw new Error("Invalid WorkerState");
    }

    // Wait until the startTime is initialized by the worker.
    while (startTime == 0) {
      try {
        startTimeInitialized.await();
      } catch (InterruptedException ignore) {
        // Ignore - it will be ready very soon.
      }
    }
  }

  @Override
  public Set<Timeout> stop() {
    if (Thread.currentThread() == workerThread) {
      throw new IllegalStateException(
          TimerWheel.class.getSimpleName() +
              ".stop() cannot be called from " +
              TimerTask.class.getSimpleName());
    }

    if (!WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_STARTED, WORKER_STATE_SHUTDOWN)) {
      // workerState can be 0 or 2 at this moment - let it always be 2.
      WORKER_STATE_UPDATER.set(this, WORKER_STATE_SHUTDOWN);

      if (leak != null) {
        leak.close();
      }

      return Collections.emptySet();
    }

    boolean interrupted = false;
    while (workerThread.isAlive()) {
      workerThread.interrupt();
      try {
        workerThread.join(100);
      } catch (InterruptedException ignored) {
        interrupted = true;
      }
    }

    if (interrupted) {
      Thread.currentThread().interrupt();
    }

    if (leak != null) {
      leak.close();
    }
    return worker.unprocessedTimeouts();
  }

  @Override
  public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv) {
    if (task == null) {
      throw new NullPointerException("task");
    }
    if (unit == null) {
      throw new NullPointerException("unit");
    }
    start();

    // Add the timeout to the timeout queue which will be processed on the next tick.
    // During processing all the queued HashedWheelTimeouts will be added to the correct HashedWheelBucket.
    long deadline = System.nanoTime() + unit.toNanos(delay) - startTime;
    HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline, argv);
    timeouts.add(timeout);
    return timeout;
  }

  private final class Worker implements Runnable {
    private final Set<Timeout> unprocessedTimeouts = new HashSet<Timeout>();

    private long tick;

    @Override
    public void run() {
      // Initialize the startTime.
      startTime = System.nanoTime();
      if (startTime == 0) {
        // We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized.
        startTime = 1;
      }

      // Notify the other threads waiting for the initialization at start().
      startTimeInitialized.countDown();

      do {
        final long deadline = waitForNextTick();
        if (deadline > 0) {
          int idx = (int) (tick & mask);
          processCancelledTasks();
          HashedWheelBucket bucket =
              wheel[idx];
          transferTimeoutsToBuckets();
          bucket.expireTimeouts(deadline);
          tick++;
        }
      } while (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_STARTED);

      // Fill the unprocessedTimeouts so we can return them from stop() method.
      for (HashedWheelBucket bucket: wheel) {
        bucket.clearTimeouts(unprocessedTimeouts);
      }
      for (;;) {
        HashedWheelTimeout timeout = timeouts.poll();
        if (timeout == null) {
          break;
        }
        if (!timeout.isCancelled()) {
          unprocessedTimeouts.add(timeout);
        }
      }
      processCancelledTasks();
    }

    private void transferTimeoutsToBuckets() {
      // Transfer at most 100000 timeouts per tick so that a thread which keeps adding new timeouts
      // in a loop cannot stall the workerThread.
      for (int i = 0; i < 100000; i++) {
        HashedWheelTimeout timeout = timeouts.poll();
        if (timeout == null) {
          // all processed
          break;
        }
        if (timeout.state() == HashedWheelTimeout.ST_CANCELLED) {
          // Was cancelled in the meantime.
          continue;
        }

        long calculated = timeout.deadline / tickDuration;
        timeout.remainingRounds = (calculated - tick) / wheel.length;

        final long ticks = Math.max(calculated, tick); // Ensure we don't schedule for past.
        int stopIndex = (int) (ticks & mask);

        HashedWheelBucket bucket = wheel[stopIndex];
        bucket.addTimeout(timeout);
      }
    }

    private void processCancelledTasks() {
      for (;;) {
        HashedWheelTimeout timeout = cancelledTimeouts.poll();
        if (timeout == null) {
          // all processed
          break;
        }
        try {
          timeout.remove();
        } catch (Throwable t) {
          if (logger.isWarnEnabled()) {
            logger.warn("An exception was thrown while processing a cancellation task", t);
          }
        }
      }
    }

    /**
     * Calculate goal nanoTime from startTime and current tick number,
     * then wait until that goal has been reached.
     * @return Long.MIN_VALUE if received a shutdown request,
     * current time otherwise (with Long.MIN_VALUE changed by +1)
     */
    private long waitForNextTick() {
      long deadline = tickDuration * (tick + 1);

      for (;;) {
        final long currentTime = System.nanoTime() - startTime;
        long sleepTimeMs = (deadline - currentTime + 999999) / 1000000;

        if (sleepTimeMs <= 0) {
          if (currentTime == Long.MIN_VALUE) {
            return -Long.MAX_VALUE;
          } else {
            return currentTime;
          }
        }

        // Check if we run on Windows, as if that's the case we will need
        // to round the sleepTime as a workaround for a bug that only affects
        // the JVM if it runs on Windows.
        //
        // See https://github.com/netty/netty/issues/356
        if (PlatformDependent.isWindows()) {
          sleepTimeMs = sleepTimeMs / 10 * 10;
        }

        try {
          Thread.sleep(sleepTimeMs);
        } catch (InterruptedException ignored) {
          if (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_SHUTDOWN) {
            return Long.MIN_VALUE;
          }
        }
      }
    }

    public Set<Timeout> unprocessedTimeouts() {
      return Collections.unmodifiableSet(unprocessedTimeouts);
    }
  }

  private static final class HashedWheelTimeout implements Timeout {

    private static final int ST_INIT = 0;
    private static final int ST_CANCELLED = 1;
    private static final int ST_EXPIRED = 2;
    private static final AtomicIntegerFieldUpdater<HashedWheelTimeout> STATE_UPDATER;

    static {
      AtomicIntegerFieldUpdater<HashedWheelTimeout> updater =
          PlatformDependent.newAtomicIntegerFieldUpdater(HashedWheelTimeout.class, "state");
      if (updater == null) {
        updater = AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimeout.class, "state");
      }
      STATE_UPDATER = updater;
    }

    private final TimerWheel timer;
    private final TimerTask task;
    private final long deadline;

    @SuppressWarnings({"unused", "FieldMayBeFinal", "RedundantFieldInitialization" })
    private volatile int state = ST_INIT;

    // remainingRounds will be calculated and set by Worker.transferTimeoutsToBuckets() before the
    // HashedWheelTimeout will be added to the correct HashedWheelBucket.
    long remainingRounds;
    String argv;

    // This will be used to chain timeouts in HashedWheelBucket via a double-linked-list.
    // As only the workerThread will act on it there is no need for synchronization / volatile.
    HashedWheelTimeout next;
    HashedWheelTimeout prev;

    // The bucket to which the timeout was added
    HashedWheelBucket bucket;

    HashedWheelTimeout(TimerWheel timer, TimerTask task, long deadline, String argv) {
      this.timer = timer;
      this.task = task;
      this.deadline = deadline;
      this.argv = argv;

    }

    @Override
    public Timer timer() {
      return timer;
    }

    @Override
    public TimerTask task() {
      return task;
    }

    @Override
    public boolean cancel() {
      // Only update the state; it will be removed from the HashedWheelBucket on the next tick.
      if (!compareAndSetState(ST_INIT, ST_CANCELLED)) {
        return false;
      }
      // If a task should be canceled we put this to another queue which will be processed on each tick.
      // So this means that we will have a GC latency of max. 1 tick duration which is good enough. This way
      // we can make again use of our MpscLinkedQueue and so minimize the locking / overhead as much as possible.
      timer.cancelledTimeouts.add(this);
      return true;
    }

    void remove() {
      HashedWheelBucket bucket = this.bucket;
      if (bucket != null) {
        bucket.remove(this);
      }
    }

    public boolean compareAndSetState(int expected, int state) {
      return STATE_UPDATER.compareAndSet(this, expected, state);
    }

    public int state() {
      return state;
    }

    @Override
    public boolean isCancelled() {
      return state() == ST_CANCELLED;
    }

    @Override
    public boolean isExpired() {
      return state() == ST_EXPIRED;
    }

    public void expire() {
      if (!compareAndSetState(ST_INIT, ST_EXPIRED)) {
        return;
      }

      try {
        task.run(this, argv);
      } catch (Throwable t) {
        if (logger.isWarnEnabled()) {
          logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t);
        }
      }
    }

    @Override
    public String toString() {
      final long currentTime = System.nanoTime();
      long remaining = deadline - currentTime + timer.startTime;

      StringBuilder buf = new StringBuilder(192)
          .append(StringUtil.simpleClassName(this))
          .append('(')
          .append("deadline: ");
      if (remaining > 0) {
        buf.append(remaining)
            .append(" ns later");
      } else if (remaining < 0) {
        buf.append(-remaining)
            .append(" ns ago");
      } else {
        buf.append("now");
      }

      if (isCancelled()) {
        buf.append(", cancelled");
      }

      return buf.append(", task: ")
          .append(task())
          .append(')')
          .toString();
    }
  }

  /**
   * Bucket that stores HashedWheelTimeouts. These are stored in a linked-list like datastructure to allow easy
   * removal of HashedWheelTimeouts in the middle. Also the HashedWheelTimeout act as nodes themself and so no
   * extra object creation is needed.
   */
  private static final class HashedWheelBucket {
    // Used for the linked-list datastructure
    private HashedWheelTimeout head;
    private HashedWheelTimeout tail;

    /**
     * Add {@link HashedWheelTimeout} to this bucket.
     */
    public void addTimeout(HashedWheelTimeout timeout) {
      assert timeout.bucket == null;
      timeout.bucket = this;
      if (head == null) {
        head = tail = timeout;
      } else {
        tail.next = timeout;
        timeout.prev = tail;
        tail = timeout;
      }
    }

    /**
     * Expire all {@link HashedWheelTimeout}s for the given {@code deadline}.
     */
    public void expireTimeouts(long deadline) {
      HashedWheelTimeout timeout = head;

      // process all timeouts
      while (timeout != null) {
        boolean remove = false;
        if (timeout.remainingRounds <= 0) {
          if (timeout.deadline <= deadline) {
            timeout.expire();
          } else {
            // The timeout was placed into a wrong slot. This should never happen.
            throw new IllegalStateException(String.format(
                "timeout.deadline (%d) > deadline (%d)", timeout.deadline, deadline));
          }
          remove = true;
        } else if (timeout.isCancelled()) {
          remove = true;
        } else {
          timeout.remainingRounds --;
        }
        // store reference to next as we may null out timeout.next in the remove block.
        HashedWheelTimeout next = timeout.next;
        if (remove) {
          remove(timeout);
        }
        timeout = next;
      }
    }

    public void remove(HashedWheelTimeout timeout) {
      HashedWheelTimeout next = timeout.next;
      // remove timeout that was either processed or cancelled by updating the linked-list
      if (timeout.prev != null) {
        timeout.prev.next = next;
      }
      if (timeout.next != null) {
        timeout.next.prev = timeout.prev;
      }

      if (timeout == head) {
        // if timeout is also the tail we need to adjust the entry too
        if (timeout == tail) {
          tail = null;
          head = null;
        } else {
          head = next;
        }
      } else if (timeout == tail) {
        // if the timeout is the tail modify the tail to be the prev node.
        tail = timeout.prev;
      }
      // null out prev, next and bucket to allow for GC.
      timeout.prev = null;
      timeout.next = null;
      timeout.bucket = null;
    }

    /**
     * Clear this bucket and return all not expired / cancelled {@link Timeout}s.
     */
    public void clearTimeouts(Set<Timeout> set) {
      for (;;) {
        HashedWheelTimeout timeout = pollTimeout();
        if (timeout == null) {
          return;
        }
        if (timeout.isExpired() || timeout.isCancelled()) {
          continue;
        }
        set.add(timeout);
      }
    }

    private HashedWheelTimeout pollTimeout() {
      HashedWheelTimeout head = this.head;
      if (head == null) {
        return null;
      }
      HashedWheelTimeout next = head.next;
      if (next == null) {
        tail = this.head = null;
      } else {
        this.head = next;
        next.prev = null;
      }

      // null out prev and next to allow for GC.
      head.next = null;
      head.prev = null;
      head.bucket = null;
      return head;
    }
  }
}
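To make the bucket arithmetic in the listing above concrete, here is a small worked example. It is only a sketch: the class `WheelMath` and its method names are made up for illustration, and it assumes a timeout scheduled the instant the worker starts, so its deadline is exactly 5 minutes and the current tick is 0. The formulas themselves mirror the listing: the wheel size is rounded up to a power of two, the slot is `ticks & mask`, and the number of rounds is `(calculated - tick) / wheel.length`.

```java
import java.util.concurrent.TimeUnit;

/** Hypothetical helper that repeats the wheel arithmetic from the listing above. */
class WheelMath {
    /** Rounds a requested wheel size up to the next power of two. */
    static int normalize(int ticksPerWheel) {
        int normalized = 1;
        while (normalized < ticksPerWheel) {
            normalized <<= 1;
        }
        return normalized;
    }

    /** Slot for a deadline given in ticks; the bit mask only works for power-of-two sizes. */
    static int slotIndex(long deadlineTicks, int wheelSize) {
        return (int) (deadlineTicks & (wheelSize - 1));
    }

    /** Number of full wheel turns to skip before the timeout may fire. */
    static long remainingRounds(long deadlineTicks, long currentTick, int wheelSize) {
        return (deadlineTicks - currentTick) / wheelSize;
    }

    public static void main(String[] args) {
        long tickNanos = TimeUnit.MILLISECONDS.toNanos(100); // default tick duration in the listing
        int wheelSize = normalize(512);                      // default wheel size in the listing
        long deadlineTicks = TimeUnit.MINUTES.toNanos(5) / tickNanos;

        System.out.println(deadlineTicks);                                // prints 3000
        System.out.println(slotIndex(deadlineTicks, wheelSize));          // prints 440
        System.out.println(remainingRounds(deadlineTicks, 0, wheelSize)); // prints 5
    }
}
```

With the defaults, one turn of the wheel covers 512 × 100 ms = 51.2 s, so a 5-minute timeout sits in slot 440 while the worker passes it five times, decrementing remainingRounds each time, and fires on the pass for tick 3000. Rounding the size up to a power of two is what lets the index be computed with a bit mask instead of a modulo.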

Write the test class Main.java

package com.tanghuachun.timer;
import java.util.concurrent.TimeUnit;

/**
 * Created by darren on 2016/11/17.
 */
public class Main implements TimerTask{
  final static Timer timer = new TimerWheel();


  public static void main(String[] args) {
    TimerTask timerTask = new Main();
    // Schedule 10 one-shot timeouts; each fires once after 5 seconds with its index as argv.
    for (int i = 0; i < 10; i++) {
      timer.newTimeout(timerTask, 5, TimeUnit.SECONDS, "" + i );
    }
    // The worker thread is not a daemon thread, so the JVM keeps running until timer.stop() is called.
  }
  @Override
  public void run(Timeout timeout, String argv) throws Exception {
    System.out.println("timeout, argv = " + argv );
  }
}

Run it and you will see the output.
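The test class fires each timeout only once, whereas the opening scenario asks for an action every 5 minutes. One common way to get periodic behaviour from a one-shot timer is to let the task schedule itself again at the end of run(). The sketch below shows only that pattern and is not part of the original project: `PeriodicDemo`, `ManualTimer` and `RepeatingTask` are hypothetical names, the three nested interfaces are trimmed stand-ins for the article's Timeout, Timer and TimerTask, and `ManualTimer` is a fake that ignores the delay so the example runs instantly.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.TimeUnit;

class PeriodicDemo {
    // Trimmed local stand-ins for the article's interfaces (only the methods used here).
    interface Timeout { Timer timer(); }
    interface Timer { Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv); }
    interface TimerTask { void run(Timeout timeout, String argv) throws Exception; }

    /** Hypothetical fake timer: ignores the delay and fires queued tasks on demand. */
    static final class ManualTimer implements Timer {
        private final Queue<Runnable> pending = new ArrayDeque<>();

        @Override
        public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv) {
            Timeout timeout = () -> this; // timer() hands back this ManualTimer
            pending.add(() -> {
                try {
                    task.run(timeout, argv);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            });
            return timeout;
        }

        /** Fires the oldest pending task; returns false when nothing is queued. */
        boolean fireNext() {
            Runnable next = pending.poll();
            if (next == null) {
                return false;
            }
            next.run();
            return true;
        }
    }

    /** Does its work, then re-arms itself so that it runs again after the same delay. */
    static final class RepeatingTask implements TimerTask {
        final List<String> log = new ArrayList<>();

        @Override
        public void run(Timeout timeout, String argv) {
            log.add(argv);
            timeout.timer().newTimeout(this, 5, TimeUnit.MINUTES, argv);
        }
    }

    public static void main(String[] args) {
        ManualTimer timer = new ManualTimer();
        RepeatingTask task = new RepeatingTask();
        timer.newTimeout(task, 5, TimeUnit.MINUTES, "job-1");
        for (int i = 0; i < 3; i++) {
            timer.fireNext();
        }
        System.out.println(task.log); // prints [job-1, job-1, job-1]
    }
}
```

Applied to the article's classes, the same single line, `timeout.timer().newTimeout(this, 5, TimeUnit.MINUTES, argv)`, would presumably go at the end of Main.run(). Note that re-arming after the work is done makes the effective period the delay plus the task's run time, so long-running work should be handed off to another thread rather than executed on the wheel's worker thread.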

(Import the project as a Maven project.)

That is all for this article. I hope it helps with your studies, and thank you for your support.