欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

扩展c3p0写的通用数据库操作工具类(使用泛型方法)

程序员文章站 2022-04-14 14:04:29
package com.syx.utils;import com.mchange.v2.c3p0.DataSources;import com.syx.annotation.ID;import com.syx.entity.User;import org.apache.commons.dbutils.QueryRunner;import org.apache.commons.dbutils.handlers.BeanHandler;import org.apache.commons.dbuti...
package com.syx.utils;

import com.mchange.v2.c3p0.DataSources;
import com.syx.annotation.ID;
import com.syx.entity.User;
import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.dbutils.handlers.BeanHandler;
import org.apache.commons.dbutils.handlers.BeanListHandler;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Generic, reflection-based CRUD helper built on Apache Commons DbUtils.
 *
 * <p>SQL conventions used by this class:
 * <ul>
 *   <li>The table name is the entity class's simple name.</li>
 *   <li>Column names are the names of the entity's non-static, non-synthetic declared fields.</li>
 *   <li>The key column used by {@link #find(Class, String)} is the first declared field annotated
 *       with {@link ID}.</li>
 * </ul>
 *
 * <p>Connection settings ({@code jdbc.url}, {@code jdbc.user}, {@code jdbc.password}) are read from
 * {@code dbconfig.properties} via {@code PropertiesUtil} when the class is initialized.
 */
public class DBUtil {
    private static final Logger LOG = Logger.getLogger(DBUtil.class.getName());

    /** Properties file holding the JDBC connection settings. */
    private static final String DB_CONFIG = "dbconfig.properties";

    /** Data source used by every operation; {@code null} if creating it failed (the failure is logged). */
    private static final DataSource dataSource;

    /** Shared QueryRunner, created lazily by {@link #queryRunner()}. */
    private static QueryRunner queryRunner = null;

    /*
     * Initializes the data source from the configuration file.
     * NOTE(review): DataSources.unpooledDataSource does not pool connections; wrap the result with
     * DataSources.pooledDataSource(...) if c3p0 pooling is intended.
     */
    static {
        DataSource created = null;
        try {
            created = DataSources.unpooledDataSource(
                    PropertiesUtil.propertyValue(DB_CONFIG, "jdbc.url"),
                    PropertiesUtil.propertyValue(DB_CONFIG, "jdbc.user"),
                    PropertiesUtil.propertyValue(DB_CONFIG, "jdbc.password"));
        } catch (SQLException e) {
            LOG.log(Level.SEVERE, "Failed to create data source from " + DB_CONFIG, e);
        }
        dataSource = created;
    }

    /**
     * Returns the shared {@link QueryRunner}, creating it on first use. Callers may use it directly
     * to run statements this class does not generate.
     *
     * @return the QueryRunner bound to this class's data source
     */
    public static QueryRunner queryRunner() {
        if (queryRunner == null) {
            queryRunner = new QueryRunner(dataSource);
        }
        return queryRunner;
    }

    /**
     * Inserts one row built from the entity's fields.
     *
     * <p>Columns and bound values are derived from the same field list, so each value is bound to
     * the column it belongs to.
     *
     * @param t   the entity to insert; must not be {@code null}
     * @param <T> the entity type
     * @return the number of affected rows, or 0 if the insert failed (the failure is logged)
     * @throws NullPointerException if {@code t} is {@code null}
     */
    public static <T> int insert(T t) {
        Objects.requireNonNull(t, "t");
        Class<?> clazz = t.getClass();
        // Compute the column list once so the SQL and the parameters share the exact same order.
        List<Field> columns = columnFields(clazz);
        String sql = insertSql(clazz, columns);
        try {
            return queryRunner().update(sql, getParams(columns, t));
        } catch (SQLException | ReflectiveOperationException e) {
            LOG.log(Level.SEVERE, "Insert failed: " + sql, e);
            return 0;
        }
    }

    /**
     * Returns the fields of {@code clazz} that map to table columns: declared, non-static and
     * non-synthetic.
     */
    private static List<Field> columnFields(Class<?> clazz) {
        List<Field> columns = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            if (!Modifier.isStatic(field.getModifiers()) && !field.isSynthetic()) {
                columns.add(field);
            }
        }
        return columns;
    }

    /**
     * Reads the value of each column field from {@code t}, in the same order as {@code columns},
     * for binding to the INSERT statement's placeholders.
     */
    private static Object[] getParams(List<Field> columns, Object t) throws ReflectiveOperationException {
        Object[] params = new Object[columns.size()];
        for (int i = 0; i < params.length; i++) {
            params[i] = readValue(columns.get(i), t);
        }
        return params;
    }

    /**
     * Reads one field's value. It prefers the public JavaBean getter ({@code getX}, or {@code isX}
     * for boolean fields) and falls back to direct field access when no such getter exists.
     */
    private static Object readValue(Field field, Object target) throws ReflectiveOperationException {
        String name = field.getName();
        String suffix = Character.toUpperCase(name.charAt(0)) + name.substring(1);
        Method getter = findGetter(target.getClass(), "get" + suffix);
        if (getter == null && (field.getType() == boolean.class || field.getType() == Boolean.class)) {
            getter = findGetter(target.getClass(), "is" + suffix);
        }
        if (getter != null) {
            return getter.invoke(target);
        }
        field.setAccessible(true);
        return field.get(target);
    }

    /** Returns the public, non-static, no-argument method {@code name}, or {@code null} if absent. */
    private static Method findGetter(Class<?> clazz, String name) {
        try {
            Method method = clazz.getMethod(name);
            return Modifier.isStatic(method.getModifiers()) ? null : method;
        } catch (NoSuchMethodException ignored) {
            // No getter: the caller falls back to reading the field directly.
            return null;
        }
    }

    /**
     * Builds {@code insert into <Table>(c1,c2,...) values (?,?,...)} with one placeholder per column.
     */
    private static String insertSql(Class<?> clazz, List<Field> columns) {
        List<String> names = new ArrayList<>(columns.size());
        for (Field field : columns) {
            names.add(field.getName());
        }
        return "insert into " + clazz.getSimpleName()
                + "(" + String.join(",", names) + ")"
                + " values (" + String.join(",", Collections.nCopies(columns.size(), "?")) + ")";
    }

    /**
     * Looks up one row by the value of the entity's {@code @ID} column.
     *
     * @param t   the entity class (also names the table)
     * @param id  the key value to match
     * @param <T> the entity type
     * @return the mapped entity, or {@code null} if no row matched or the query failed (the failure
     *         is logged)
     * @throws IllegalArgumentException if {@code t} declares no field annotated with {@link ID}
     */
    public static <T> T find(Class<T> t, String id) {
        String sql = querySql(t, false);
        try {
            return queryRunner().query(sql, new BeanHandler<>(t), id);
        } catch (SQLException e) {
            LOG.log(Level.SEVERE, "Query failed: " + sql, e);
            return null;
        }
    }

    /**
     * Selects every row of the entity's table.
     *
     * @param t   the entity class (also names the table)
     * @param <T> the entity type
     * @return all mapped rows, or {@code null} if the query failed (the failure is logged)
     */
    public static <T> List<T> findAll(Class<T> t) {
        String sql = querySql(t, true);
        try {
            return queryRunner().query(sql, new BeanListHandler<>(t));
        } catch (SQLException e) {
            LOG.log(Level.SEVERE, "Query failed: " + sql, e);
            return null;
        }
    }

    /**
     * Builds a SELECT statement for the entity's table.
     *
     * @param t    the entity class
     * @param flag {@code true} to select all rows; {@code false} to select by the {@code @ID} column
     * @param <T>  the entity type
     * @return the SQL text; the by-id form has a single {@code ?} placeholder
     * @throws IllegalArgumentException if {@code flag} is {@code false} and no field carries {@link ID}
     */
    private static <T> String querySql(Class<T> t, boolean flag) {
        String tableName = t.getSimpleName();
        if (flag) {
            return "select * from " + tableName;
        }
        for (Field field : t.getDeclaredFields()) {
            if (field.isAnnotationPresent(ID.class)) {
                return "select * from " + tableName + " where " + field.getName() + " = ?";
            }
        }
        // Previously this silently produced "where  = ?", which can never be valid SQL.
        throw new IllegalArgumentException(t.getName() + " has no field annotated with @ID");
    }

    public static void main(String[] args) {
//        List<User> all = findAll(User.class);
//        all.forEach(System.out::println);
//        User user = find(User.class, "1");
//        System.out.println(user);

        User user = new User();

        user.setId(7);
        user.setSex("女");

        user.setUsername("zyz");
        insert(user);
    }
}

本文地址:https://blog.csdn.net/weixin_42304484/article/details/112210858

相关标签: Java工具 mysql