欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

第三章 - 最大化使用Executors

程序员文章站 2022-07-12 19:51:06
...

Executors的一些高级特性

取消任务

当你把任务发送给 executor 后,你可以选择取消这个任务的执行。使用 submit() 方法发送一个 Runnable 对象给一个executor,submit() 方法将返回一个实现了 Future 这个接口类的对象。你可以通过该类的 cancel() 方法取消任务的执行。cancel() 方法接收一个 boolean 值作为参数。如果参数为 true,且executor正在执行这个任务,那么执行任务的线程将被中断。

cancel() 方法返回一个boolean值显示任务是否被取消。

 

任务调度

ThreadPoolExecutor类是接口 Executor 和 ExecutorService的基本实现类。同时Java提供了一个扩展类用以实现任务的调度 - ScheduledThreadPoolExecutor类,通过此类你可以:

  • 延迟执行一个任务
  • 周期性执行任务,包括以一定的频率执行任务和固定的延迟时间执行任务

重载 executor 的方法

你可以通过继承现有的类 (ThreadPoolExecutor 或者 ScheduledThreadPoolExecutor)来实现自定义的executor。如果你继承了 ThreadPoolExecutor 类,你可以重载以下方法:

  • beforeExecute():这个方法在executor里的并行任务执行前被调用。此方法接收了两个参数:即将被执行的Runnable对象以及负责执行的Thread对象。注意这个方法接收的Runnable对象实际是类FutureTask的一个实例,并不是你通过 submit() 方法发送给 executor 的Runnable对象(该对象已经被包装成了FutureTask)。
  • afterExecute():这个方法在executor里的并行任务执行完之后被调用。它接收两个参数:被执行完的Runnable对象以及一个保存可能从任务里抛出的异常。跟beforeExecute()方法很像,Runnable对象是FutureTask类的一个实例。
  • newTaskFor():这个方法创建了一个任务,用来执行你通过submit()方法发送来的Runnable对象。它必须返回一个RunnableFuture接口的实现。默认的,OpenJDK 8 和 Oracle JDK 8 返回的是FutureTask类的实例,但是这在未来的JDK版本中可能不一样。

如果你继承了ScheduledThreadPoolExecutor类,你可以重载decorateTask()方法。这个方法就像上面的newTaskFor(),但它是针对调度任务的。它允许你重载被executor执行的任务。

 

修改一些初始化参数

你也可以修改一些参数来改变executor的行为。其中包括:

  • BlockingQueue<Runnable>:每一个executor内部都维护一个BlockingQueue来保存即将被执行的任务。你可以传入任何一个该接口的实现类。例如,你可以改变executor执行任务的默认顺序。
  • ThreadFactory:你可以传入一个实现了ThreadFactory接口的自定义类,executor会使用这个自定义线程工厂来创建用来执行任务的线程。例如,你的自定义ThredFactory类返回的自定义线程能够保存每个任务的执行时间到日志中。
  • RejectedExecutionHandler:在你调用了shutdown()或shutdownNow()方法后,所有发送给executor的新任务都会被拒绝。你可以传入自定义的实现了RejectedExecutionHandler接口的类来处理这种情况。

第一个例子 - 服务器应用

在第二章中,我们实现了客户端 / 服务器端应用程序。在这个例子中我们将对那个应用程序做如下拓展:

 

  • 引入一个新的请求用来取消已经发送给服务器的请求
  • 每个请求允许传递一个新的参数代表请求的优先权。用来控制请求的执行顺序
  • 服务器能够计算每个用户已经执行的请求总数以及执行的总耗时

 

// 这个类用来记录已经执行的任务总数以及这些任务的总耗时
// 两个变量都是原子变量,因为不同线程需要更新这两个参数的值
public class ExecutorStatistics {
    private AtomicLong executionTime = new AtomicLong(0L);
    private AtomicInteger numTasks = new AtomicInteger(0);

    public void addExecutionTime(long time) {
        executionTime.addAndGet(time);
    }

    public void addTask() {
        numTasks.incrementAndGet();
    }

    @Override
    public String toString() {
        return "Executed Tasks: " + getNumTasks() + 
                ". Execution Time: "+ getExecutionTime();
    }

    public AtomicLong getExecutionTime() {
        return executionTime;
    }

    public AtomicInteger getNumTasks() {
        return numTasks;
    }
}
 

 

 

// 当executor被调用shutdown()或shutdownNow()后,executor会拒绝新提交的任务
// 这个类用来处理这种情况,这里它所做的是往输出流里打印出错信息
public class RejectedTaskController implements RejectedExecutionHandler {
    @Override
    public void rejectedExecution(Runnable task,
                                  ThreadPoolExecutor executor) {
        ConcurrentCommand command = (ConcurrentCommand) task;
        Socket clientSocket = command.getSocket();
        try {
            PrintWriter out = new
                    PrintWriter(clientSocket.getOutputStream(), true);
            String message = "The server is shutting down."
                    + " Your request can not be served."
                    + " Shutting Down: "
                    + executor.isShutdown()
                    + ". Terminated: "
                    + executor.isTerminated()
                    + ". Terminating: "
                    + executor.isTerminating();
            out.println(message);
            out.close();
            clientSocket.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
 

 

 

/**
 * 当一个Runnable对象提交给executor后,executor不是直接执行这个Runnable对象
 * 它创建了一个FutureTask类的实例对象,该对象被executor里的线程所执行
 * 这里我们继承了FutureTask类并实现了Comparable接口,这样每个提交给executor
 * 的任务可以按照一定规则排序(如优先级)
 */
public class ServerTask<V> extends FutureTask<V> implements Comparable<ServerTask<V>> {
    private ConcurrentCommand command;
    public ServerTask(ConcurrentCommand command) {
        super(command, null);
        this.command=command;
    }

    public ConcurrentCommand getCommand() {
        return command;
    }
    public void setCommand(ConcurrentCommand command) {
        this.command = command;
    }

    @Override
    public int compareTo(ServerTask<V> other) {
        return command.compareTo(other.getCommand());
    }
}
 

 

// 重载executor,以根据我们的需要修改它的行为
public class ServerExecutor extends ThreadPoolExecutor {
    // 记录每个任务的执行时间,主键是ServerTask对象(即Runnable对象)
    // 键值是对应的日期
    private ConcurrentHashMap<Runnable, Date> startTimes;

    // 这个变量记录每个用户的统计数据,主键是用户名
    // 键值是ExecutorStatistics对象
    private ConcurrentHashMap<String, ExecutorStatistics>
            executionStatistics;

    private static int CORE_POOL_SIZE =
            Runtime.getRuntime().availableProcessors();
    private static int MAXIMUM_POOL_SIZE =
            Runtime.getRuntime().availableProcessors();
    private static long KEEP_ALIVE_TIME = 10;
    private static RejectedTaskController REJECTED_TASK_CONTROLLER
            = new RejectedTaskController();

    public ServerExecutor() {
        super(CORE_POOL_SIZE, MAXIMUM_POOL_SIZE, KEEP_ALIVE_TIME,
                TimeUnit.SECONDS, new PriorityBlockingQueue<>(),
                REJECTED_TASK_CONTROLLER);
        startTimes = new ConcurrentHashMap<>();
        executionStatistics = new ConcurrentHashMap<>();
    }

    // 每个任务被执行前调用,这里记录每个任务的开始时间
    protected void beforeExecute(Thread t, Runnable r) {
        super.beforeExecute(t, r);
        startTimes.put(r, new Date());
    }

    // 每个任务执行完后调用,这里我们计算执行当前任务的耗时,
    // 并更新用户已被执行的任务总数和总耗时
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        super.afterExecute(r, t);
        ServerTask<?> task = (ServerTask<?>) r;
        ConcurrentCommand command = task.getCommand();
        if (t == null) {
            if (!task.isCancelled()) {
                // 首先从startTimes中删除此任务的运行开始时间
                Date startDate = startTimes.remove(r);
                Date endDate = new Date();
                long executionTime = endDate.getTime() - startDate.getTime();
                ExecutorStatistics statistics =
                        executionStatistics.computeIfAbsent
                                (command.getUsername(), n -> new ExecutorStatistics());
                statistics.addExecutionTime(executionTime);
                statistics.addTask();

                // 从ConcurrentServer中维护的任务列表中删除此任务,因为它已经完成
                ConcurrentServer.finishTask(command.getUsername(), command);
            } else {
                String message = "The task"
                        + command.hashCode() + "of user"
                        + command.getUsername() + "has been cancelled. ";
                System.out.println(message);
            }
        } else {
            String message = "The exception "
                    + t.getMessage()
                    + " has been thrown.";
            System.out.println(message);
        }
    }

    // 把发送给executor的Runnable对象包装成ServerTask对象,该对象才真正是被线程执行的
    @Override
    protected <T> RunnableFuture<T> newTaskFor(Runnable runnable,
                                               T value) {
        return new ServerTask<T>(runnable);
    }

    public void writeStatistics() {
        for (Map.Entry<String, ExecutorStatistics> entry : executionStatistics.entrySet()) {
            String user = entry.getKey();
            ExecutorStatistics stats = entry.getValue();
            System.out.println(user + ":" + stats);
        }
    }
}
 

 

 

// 指令的抽象类
public abstract class Command {
    protected String[] command;
    public Command (String [] command) {
        this.command=command;
    }
    public abstract String execute ();
}

/**
 * 这是所有指令类的基础类,它包括了一些所有指令的公共行为
 * 1. 调用每个指令类的具体逻辑实现
 * 2. 把指令结果写会给客户端
 * 3. 关闭所有通信中使用的资源
 */
public abstract class ConcurrentCommand extends Command implements Comparable<ConcurrentCommand>, Runnable {
    private String username;
    private byte priority;
    private Socket socket;
    public ConcurrentCommand(Socket socket, String[] command) {
        super(command);
        username=command[1];
        priority=Byte.parseByte(command[2]);
        this.socket=socket;
    }

    @Override
    public abstract String execute();

    @Override
    public void run() {
        String ret = execute();

        try {
            PrintWriter out = new
                    PrintWriter(socket.getOutputStream(),true);
            out.println(ret);
            socket.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override
    public int compareTo(ConcurrentCommand o) {
        return Byte.compare(o.getPriority(), this.getPriority());
    }

    public byte getPriority() {
        return priority;
    }

    public Socket getSocket() {
        return socket;
    }

    public String getUsername() {
        return username;
    }
}

// 对应Query请求的指令类
public class ConcurrentQueryCommand extends ConcurrentCommand {
    public ConcurrentQueryCommand(Socket socket, String [] command) {
        super(socket, command);
    }

    public String execute() {
        WDIDAO dao=WDIDAO.getDAO();
        if (command.length==5) {
            return dao.query(command[3], command[4]);
        } else if (command.length==6) {
            try {
                return dao.query(command[3], command[4],
                        Short.parseShort(command[5]));
            } catch (NumberFormatException e) {
                return "ERROR;Bad Command";
            }
        } else {
            return "ERROR;Bad Command";
        }
    }
}

//对应Report请求的指令类
public class ConcurrentReportCommand extends ConcurrentCommand {
    public ConcurrentReportCommand(Socket socket, String [] command) {
        super(socket, command);
    }

    public String execute() {
        WDIDAO dao=WDIDAO.getDAO();
        return dao.report(command[3]);
    }
}

//对应Stop请求的指令类
public class ConcurrentStopCommand extends ConcurrentCommand {
    public ConcurrentStopCommand(Socket socket, String [] command) {
        super(socket, command);
    }

    public String execute() {
        ConcurrentServer.shutdown();
        return "Server stopped";
    }
}

//对应Cancel请求的指令类
public class ConcurrentCancelCommand extends ConcurrentCommand {
    public ConcurrentCancelCommand(Socket socket, String [] command) {
        super(socket, command);
    }

    public String execute() {
        ConcurrentServer.cancelTasks(getUsername());
        return message;
    }
}

//此类处理一些服务器不支持的请求
public class ConcurrentErrorCommand extends ConcurrentCommand {
    public ConcurrentErrorCommand(Socket socket, String [] command) {
        super(socket, command);
    }

    public String execute() {
        return "Unknown command: " + command[0];
    }
}

//对应status请求的指令
public class ConcurrentStatusCommand extends ConcurrentCommand {
    public ConcurrentStatusCommand (Socket socket, String[] command) {
        super(socket, command);
    }

    @Override
    public String execute() {
        StringBuilder sb=new StringBuilder();
        ThreadPoolExecutor executor = ConcurrentServer.getExecutor();
        sb.append("Server Status;");
        sb.append("Actived Threads: ");
        sb.append(executor.getActiveCount());
        sb.append(";");
        sb.append("Maximum Pool Size: ");
        sb.append(executor.getMaximumPoolSize());
        sb.append(";");
        sb.append("Core Pool Size: ");
        sb.append(executor.getCorePoolSize());
        sb.append(";");
        sb.append("Pool Size: ");
        sb.append(executor.getPoolSize());
        sb.append(";");
        sb.append("Largest Pool Size: ");
        sb.append(executor.getLargestPoolSize());
        sb.append(";");
        sb.append("Completed Task Count: ");
        sb.append(executor.getCompletedTaskCount());
        sb.append(";");
        sb.append("Task Count: ");
        sb.append(executor.getTaskCount());
        sb.append(";");
        sb.append("Queue Size: ");
        sb.append(executor.getQueue().size());
        sb.append(";");
        return sb.toString();
    }
}
 

 

/**
 * 在这个类中,我们启动了RequestTask这个线程,该线程读取由ConcurrentServer保存的客户端socket,
 * 创建相应的指令并发给executor执行。
 * 这样做的目的是让被线程执行的每个任务只包含和请求相关的代码,其它操作可以在executor外处理
 */
public class ConcurrentServer {
    private static volatile boolean stopped = false;

    // 用来保存发送消息给服务器的客户的sockets
    private static LinkedBlockingQueue<Socket> pendingConnections;

    // 保存每一个在executor里执行的任务所关联的Future对象,主键是用户名,
    // 键值是另外一个ConcurrentMap (它的主键是ConcurrentCommand,键值是和任务关联的Future实例)
    private static ConcurrentMap<String, ConcurrentMap<ConcurrentCommand, ServerTask<?>>> taskController;

    // 执行RequestTask对象的Thread
    private static Thread requestThread;

    // 创建指令对象并发送给executor
    private static RequestTask task;

    private static ServerSocket serverSocket;

    public static void main(String[] args) {
        pendingConnections = new LinkedBlockingQueue<>();
        taskController = new ConcurrentHashMap<String, ConcurrentHashMap<ConcurrentCommand, Future<?>>>();

        // 启动RequestTask线程
        task = new RequestTask(pendingConnections, taskController);
        requestThread = new Thread(task);
        requestThread.start();

        System.out.println("Initialization completed.");

        serverSocket = new ServerSocket(Constants.CONCURRENT_PORT);
        do {
            try {
                Socket clientSocket = serverSocket.accept();
                pendingConnections.put(clientSocket);
            } catch (Exception e) {
                e.printStackTrace();
            }
        } while (!stopped);
        finishServer();
    }

    // 该方法修改stopped变量为true并关闭serverSocket
    public static void shutdown() {
        stopped = true;
        try {
            serverSocket.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // 该方法停止executor并中断RequestTask线程
    private static void finishServer() {
        System.out.println("Shutting down the server...");
        task.shutdown();
        System.out.println("Shutting down Request task");
        requestThread.interrupt();
        System.out.println("Request task ok");
        System.out.println("Closing socket");
        System.out.println("Shutting down logger");
        System.out.println("Logger ok");
        System.out.println("Main server thread ended");
    }

    // 取消一个用户的请求
    public static void cancelTasks(String username) {
        ConcurrentMap<ConcurrentCommand, ServerTask<?>> userTasks = taskController.get(username);
        if (userTasks == null) {
            return;
        }
        int taskNumber = 0;
        Iterator<ServerTask<?>> it = userTasks.values().iterator();
        while(it.hasNext()) {
            ServerTask<?> task = it.next();
            ConcurrentCommand command = task.getCommand();
            if(!(command instanceof ConcurrentCancelCommand) &&
                    task.cancel(true)) {
                taskNumber++;
                it.remove();
            }
        }
    }

    // 当一个任务顺利完成后,我们需要把和这个任务关联的Future对象从保存的ConcurrentMap中删除
    public static void finishTask(String username, ConcurrentCommand command) {
        ConcurrentMap<ConcurrentCommand, ServerTask<?>> userTasks
                = taskController.get(username);
        userTasks.remove(command);
    }
}


public class RequestTask implements Runnable {
    // 保存客户端的sockets
    private LinkedBlockingQueue<Socket> pendingConnections;

    // 用来并行处理用户的请求指令
    private ServerExecutor executor = new ServerExecutor();

    // 保存和任务相关联的Future对象
    private ConcurrentMap<String, ConcurrentMap<ConcurrentCommand, ServerTask<?>>> taskController;

    public RequestTask(LinkedBlockingQueue<Socket>
                               pendingConnections, ConcurrentHashMap<String,
            ConcurrentHashMap<Integer, Future<?>>> taskController) {
        this.pendingConnections = pendingConnections;
        this.taskController = taskController;
    }

    public void run() {
        try {
            while (!Thread.currentThread().interrupted()) {
                try {
                    Socket clientSocket = pendingConnections.take();
                    BufferedReader in = new BufferedReader(new InputStreamReader(clientSocket.getInputStream()));
                    String line = in.readLine();
                    ConcurrentCommand command;

                    String[] commandData = line.split(";");
                    System.out.println("Command: " + commandData[0]);
                    switch (commandData[0]) {
                        case "q":
                            System.out.println("Query");
                            command = new ConcurrentQueryCommand(clientSocket, commandData);
                            break;
                        case "r":
                            System.out.println("Report");
                            command = new ConcurrentReportCommand(clientSocket, commandData);
                            break;
                        case "s":
                            System.out.println("Status");
                            command = new ConcurrentStatusCommand(executor, clientSocket, commandData);
                            break;
                        case "z":
                            System.out.println("Stop");
                            command = new ConcurrentStopCommand(clientSocket, commandData);
                            break;
                        case "c":
                            System.out.println("Cancel");
                            command = new ConcurrentCancelCommand(clientSocket, commandData);
                            break;
                        default:
                            System.out.println("Error");
                            command = new ConcurrentErrorCommand(clientSocket, commandData);
                            break;
                    }
                    ServerTask<?> controller = (ServerTask<?>) executor.submit(command);
                    storeContoller(command.getUsername(), controller, command);

                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        } catch (InterruptedException e) {
            // No Action Required
        }
    }

    // 保存和用户请求关联的Future对象
    private void storeContoller(String userName, ServerTask<?>controller, ConcurrentCommand command) {
        taskController.computeIfAbsent(userName, k -> new ConcurrentHashMap<ConcurrentCommand, ServerTask<?>>()).put(command, controller);
    }

    // 关闭executor
    public void shutdown() {
        String message = "Request Task: "
                + pendingConnections.size()
                + " pending connections.";
        System.out.println(message);
        executor.shutdown();
    }

    // 等待executor执行完所有正在执行的任务
    public void terminate() {
        try {
            executor.awaitTermination(1, TimeUnit.DAYS);
            executor.writeStatistics();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

 

Executor额外的一些方法

你也可以重载以下Executor的方法:

  • shutdown():你必须调用这个方法来结束executor。你可以重载这个方法来释放额外的资源。这个方法等待executor处理完所有排队被执行的任务
  • shutdownNow():和sutdown()方法不同的是,该方法不等待executor处理完正在排队的任务
  • submit(), invokeall(), invokeany():调用这些方法发送并发任务给executor。如果你需要在任务添加到executor里的任务队列之前或之后做一些额外的操作,你可以重载它们。请注意在任务插入队列之前或之后执行的额外操作不同于任务执行前或执行后,如果想要在任务执行前后做一些额外操作,你必须重载beforeExecute()和afterExecute()方法。

ScheduledThreadPoolExecutor类有以下方法允许延迟执行任务,或周期性的任务:

  • schedule():这个方法允许在指定的延迟之后执行任务。任务只被执行一次。
  • scheduleAtFixedRate():这个方法允许在指定的延迟之后周期性地执行任务。它和方法scheduleWithFixedDelay()不同的是:对于scheduleWithFixedDelay()方法,两次执行的时间间隔是上一次执行结束到下一次执行开始之间的时间间隔。对于scheduleAtFixedRate(),两次执行的时间间隔是上一次执行开始到下一次执行开始之间的时间间隔。

 

相关标签: java 多线程