欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

详解Java实现的k-means聚类算法

程序员文章站 2023-12-17 13:15:04
需求 对mysql数据库中某个表的某个字段执行k-means算法,将处理后的数据写入新表中。 源码及驱动 源码 import java.sql.*;...

需求

对mysql数据库中某个表的某个字段执行k-means算法,将处理后的数据写入新表中。

源码及驱动

源码

import java.sql.*;
import java.util.*;

/**
 * @author tianshl
 * @version 2018/1/13 上午11:13
 */
public class kmeans {
  // 源数据
  private list<integer> origins = new arraylist<>();

  // 分组数据
  private map<double, list<integer>> grouped;

  // 初始质心列表
  private list<double> cores;

  // 数据源
  private string tablename;
  private string colname;

  /**
   * 构造方法
   *
   * @param tablename 源数据表名称
   * @param colname  源数据列名称
   * @param cores   质心列表
   */
  private kmeans(string tablename, string colname,list<double> cores){
    this.cores = cores;
    this.tablename = tablename;
    this.colname = colname;
  }

  /**
   * 重新计算质心
   *
   * @return 新的质心列表
   */
  private list<double> newcores(){
    list<double> newcores = new arraylist<>();

    for(list<integer> v: grouped.values()){
      newcores.add(v.stream().reduce(0, (sum, num) -> sum + num) / (v.size() + 0.0));
    }

    collections.sort(newcores);
    return newcores;
  }

  /**
   * 判断是否结束
   *
   * @return bool
   */
  private boolean isover(){
    list<double> _cores = newcores();
    for(int i=0, len=cores.size(); i<len; i++){
      if(!cores.get(i).tostring().equals(_cores.get(i).tostring())){
        // 使用新质心
        cores = _cores;
        return false;
      }
    }
    return true;
  }

  /**
   * 数据分组
   */
  private void setgrouped(){
    grouped = new hashmap<>();

    double core;
    for (integer origin: origins) {
      core = getcore(origin);

      if (!grouped.containskey(core)) {
        grouped.put(core, new arraylist<>());
      }

      grouped.get(core).add(origin);
    }
  }

  /**
   * 选择质心
   *
   * @param num  要分组的数据
   * @return   质心
   */
  private double getcore(integer num){

    // 差 列表
    list<double> diffs = new arraylist<>();

    // 计算差
    for(double core: cores){
      diffs.add(math.abs(num - core));
    }

    // 最小差 -> 索引 -> 对应的质心
    return cores.get(diffs.indexof(collections.min(diffs)));
  }

  /**
   * 建立数据库连接
   * @return connection
   */
  private connection getconn(){
    try {
      // url指向要访问的数据库名mydata
      string url = "jdbc:mysql://localhost:3306/data_analysis_dev";
      // mysql配置时的用户名
      string user = "root";
      // mysql配置时的密码
      string password = "root";

      // 加载驱动
      class.forname("com.mysql.jdbc.driver");

      //声明connection对象
      connection conn = drivermanager.getconnection(url, user, password);

      if(conn.isclosed()){
        system.out.println("连接数据库失败!");
        return null;
      }
      system.out.println("连接数据库成功!");

      return conn;

    } catch (exception e) {
      system.out.println("连接数据库失败!");
      e.printstacktrace();
    }

    return null;
  }

  /**
   * 关闭数据库连接
   *
   * @param conn 连接
   */
  private void close(connection conn){
    try {
      if(conn != null && !conn.isclosed()) conn.close();
    } catch (exception e){
      e.printstacktrace();
    }
  }

  /**
   * 获取源数据
   */
  private void getorigins(){

    connection conn = null;
    try {
      conn = getconn();
      if(conn == null) return;

      statement statement = conn.createstatement();

      resultset rs = statement.executequery(string.format("select %s from %s", colname, tablename));

      while(rs.next()){
        origins.add(rs.getint(1));
      }
      conn.close();
    } catch (exception e){
      e.printstacktrace();
    } finally {
     close(conn);
    }
  }

  /**
   * 向新表中写数据
   */
  private void write(){

    connection conn = null;
    try {
      conn = getconn();
      if(conn == null) return;
      
      // 创建表
      statement statement = conn.createstatement();

      // 删除旧数据表
      statement.execute("drop table if exists k_means; ");
      // 创建新表
      statement.execute("create table if not exists k_means(`core` decimal(11, 7), `col` integer(11));");

      // 禁止自动提交
      conn.setautocommit(false);

      preparedstatement ps = conn.preparestatement("insert into k_means values (?, ?)");

      for(map.entry<double, list<integer>> entry: grouped.entryset()){
        double core = entry.getkey();
        for(integer value: entry.getvalue()){
          ps.setdouble(1, core);
          ps.setint(2, value);
          ps.addbatch();
        }
      }

      // 批量执行
      ps.executebatch();

      // 提交事务
      conn.commit();

      // 关闭连接
      conn.close();
    } catch (exception e){
      e.printstacktrace();
    } finally {
      close(conn);
    }
  }

  /**
   * 处理数据
   */
  private void run(){
    system.out.println("获取源数据");
    // 获取源数据
    getorigins();

    // 停止分组
    boolean isover = false;

    system.out.println("数据分组处理");
    while(!isover) {
      // 数据分组
      setgrouped();
      // 判断是否停止分组
      isover = isover();
    }

    system.out.println("将处理好的数据写入数据库");
    // 将分组数据写入新表
    write();

    system.out.println("写数据完毕");
  }

  public static void main(string[] args){
    list<double> cores = new arraylist<>();
    cores.add(260.0);
    cores.add(600.0);
    // 表名, 列名, 质心列表
    new kmeans("attributes", "attr_length", cores).run();
  }
}

源文件

kmeans.java

编译

javac kmeans.java 

运行

# 指定依赖库
java -djava.ext.dirs=./lib kmeans

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。

上一篇:

下一篇: