前言
在上节中我们分析了Mybatis的源码,但是内容比较多,而且没有那么通俗易懂,所以我自己针对于Mybatis的核心部分纯手写了一遍。
本节纯手写Mybatis意在了解Mybatis的运行机制,结合XML解析,反射和代理等知识完成Mybatis的基本操作,对之前我们讲到的知识点,起到画龙点睛的作用,代码不完善的地方还请谅解啦。
(以下代码纯本人编写,无任何其他博文参考,如需转载,请保留本博客链接或标注来源哦~)
实现原理图
为了方便大家理解,我这里专门针对与我这次手写的Mybatis花了一张流程图,大家可以参照我这个流程图一步步的进行阅读,加强大家的理解~
XML解析为Mapper对象
在这里我们就略过mybatis.config.xml的配置,直奔主题,直接解析我们的UserMapper.xml文件,虽然XML文件看了不少遍,不过也不在乎多这一遍啦~
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.marco.dao.UserMapper">
<resultMap type="com.marco.bean.User" id="userResultMap">
<result column="id" property="id"/>
<result column="user_name" property="userName"/>
<result column="real_name" property="realName"/>
<result column="password" property="password"/>
</resultMap>
<!-- 添加数据 -->
<insert id="addUser" parameterType="Map">
insert into u_user(user_name,password,real_name) values(#{userName},#{password},#{realName})
</insert>
<!-- 查询单个数据 -->
<select id="selectById" resultMap="userResultMap" parameterType="Map">
select * from u_user where id=#{id}
</select>
<!-- 查询所有数据 -->
<select id="selectAll" resultMap="userResultMap">
select * from u_user
</select>
</mapper>
我们稍微弄简单一点,做一个查询和添加的功能。拿到这个XML文件,我们就要分析它的结构。
首先,最外面包着的一层是mapper,mapper有属性namespace,那我们能不能把mapper当作一个类?
我们接着看mapper里面有resultMap,resultMap里是多个result,那么我们可以考虑将result封装成一个类,resultMap包含result对象的List集合,id和type也同为它的属性。
接着就是我们的insert标签和select标签,大家发现这两个标签的结构其实很相似,都有id属性,parameterType属性,和里面的sql语句,除了resultMap,其他属性基本一致,那么我们是不是可以专门把这些信息封装为一个抽象类,然后分别生成实现类Insert和Select实现它,按照源码的命名规范,我们这里把包含单个语句映射关系的类称之为MapppedStatement。
那么理清楚结构后我们就可以开始着手写代码啦!我这里将这里代码全部归类到
com.marco.domain
中
Mapper 类
package com.marco.domain;
import java.util.Map;
public class Mapper {
private String namespace;
private ResultMap resultMap;
private Map<String, MappedStatement> mappedStatements;
public Mapper() {
super();
}
public Mapper(String namespace, ResultMap resultMap, Map<String, MappedStatement> mappedStatements) {
super();
this.namespace = namespace;
this.resultMap = resultMap;
this.mappedStatements = mappedStatements;
}
public Map<String, MappedStatement> getMappedStatement() {
return mappedStatements;
}
public void setMappedStatement(Map<String, MappedStatement> mappedStatements) {
this.mappedStatements = mappedStatements;
}
//通过Id(接口方法名称)对应的mappedStatement(insert或者select对象等)
public MappedStatement getMappedStatement(String key) {
return mappedStatements.get(key);
}
public String getNamespace() {
return namespace;
}
public void setNamespace(String namespace) {
this.namespace = namespace;
}
public ResultMap getResultMap() {
return resultMap;
}
public void setResultMap(ResultMap resultMap) {
this.resultMap = resultMap;
}
@Override
public String toString() {
return "Mapper [namespace=" + namespace + ", resultMap=" + resultMap + ", mappedStatements=" + mappedStatements
+ "]";
}
}
Result类
package com.marco.domain;
public class Result {
private String column;
private String property;
public Result() {
super();
}
public Result(String column, String property) {
super();
this.column = column;
this.property = property;
}
public String getColumn() {
return column;
}
public void setColumn(String column) {
this.column = column;
}
public String getProperty() {
return property;
}
public void setProperty(String property) {
this.property = property;
}
@Override
public String toString() {
return "Result [column=" + column + ", property=" + property + "]";
}
}
ResultMap类
package com.marco.domain;
import java.util.List;
public class ResultMap {
private String id;
private String type;
private List<Result> results;
public ResultMap() {
super();
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public List<Result> getResults() {
return results;
}
public void setResults(List<Result> results) {
this.results = results;
}
@Override
public String toString() {
return "ResultMap [id=" + id + ", type=" + type + ", results=" + results + "]";
}
}
MappedStatement抽象类
package com.marco.domain;
import com.marco.enums.SqlCommandType;
public abstract class MappedStatement {
protected String id;
protected String parameterType;
protected String resultMap;
protected String sql;
protected SqlCommandType type;
public MappedStatement() {
super();
}
public MappedStatement(String id, String parameterType, String resultMap, String sql, SqlCommandType type) {
super();
this.id = id;
this.parameterType = parameterType;
this.resultMap = resultMap;
this.sql = sql;
this.type = type;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getParameterType() {
return parameterType;
}
public void setParameterType(String parameterType) {
this.parameterType = parameterType;
}
public String getResultMap() {
return resultMap;
}
public void setResultMap(String resultMap) {
this.resultMap = resultMap;
}
public String getSql() {
return sql;
}
public void setSql(String sql) {
this.sql = sql;
}
public SqlCommandType getType() {
return type;
}
public void setType(SqlCommandType type) {
this.type = type;
}
@Override
public String toString() {
return "MappedStatement [id=" + id + ", parameterType=" + parameterType + ", resultMap=" + resultMap + ", sql="
+ sql + ", type=" + type + "]";
}
}
Select类
package com.marco.domain;
import com.marco.enums.SqlCommandType;
public class Select extends MappedStatement{
public Select() {
super();
}
public Select(String id, String parameterType, String resultMap, String sql, SqlCommandType type) {
super(id, parameterType, resultMap, sql, type);
}
}
Insert类
package com.marco.domain;
import com.marco.enums.SqlCommandType;
public class Insert extends MappedStatement{
public Insert() {
super();
}
public Insert(String id, String parameterType, String resultMap, String sql, SqlCommandType type) {
super(id, parameterType, resultMap, sql, type);
}
}
到此为止,基本的关联XML的bean类都已经创建完成了,接下来就是我们解析XML的环节了,这里我就直接上代码了
public class ParseXML {
/**
* @param <T>
* @return 返回解析完成的Mapper.xml的mapper对象
* @throws DocumentException
*/
public <T> Mapper parse() throws DocumentException {
SAXReader reader = new SAXReader();
Mapper mapper = new Mapper();
Document read = reader.read(ParseXML.class.getResourceAsStream("UserMapper.xml"));
Element elements = read.getRootElement();
String namespace = elements.attribute(0).getValue();
mapper.setNamespace(namespace);
Iterator<?> iterator = elements.elementIterator();
Map<String, MappedStatement> mappedStatements = new HashMap<String, MappedStatement>();
while(iterator.hasNext()) {
Element element = (Element) iterator.next();
String command = element.getName();
if("resultMap".equals(command)) {
String id = element.attribute("id").getValue();
String type = element.attribute("type").getValue();
ResultMap resultMap = new ResultMap();
resultMap.setId(id);
resultMap.setType(type);
Iterator<?> iterator2 = element.elementIterator();
List<Result> results = new ArrayList<Result>();//存放Result信息
while (iterator2.hasNext()) {
Element element2 = (Element) iterator2.next();
String column = element2.attribute("column").getValue();
String property = element2.attribute("property").getValue();
Result result = new Result(column, property);
results.add(result);
}
resultMap.setResults(results);//将Result的List集合存放在resultMap中
mapper.setResultMap(resultMap);
} else if(SqlCommandType.INSERT.getType().equals(command)) {
String id = element.attribute("id").getValue();
String parameterType = "";
if(!isBlank(element.attribute("parameterType").getValue())) {
parameterType = element.attribute("parameterType").getValue();
}
String sql = element.getText();
MappedStatement insert = new Insert(id, parameterType, null, sql, SqlCommandType.INSERT);
String key = namespace + "." + id;
mappedStatements.put(key, insert);
} else if(SqlCommandType.SELECT.getType().equals(command)) {
String id = element.attribute("id").getValue();
String parameterType = "";
if(!isBlank(element.attribute("parameterType"))) {
parameterType = element.attribute("parameterType").getValue();
}
String resultMap = "";
if(!isBlank(element.attribute("resultMap"))) {
resultMap = element.attribute("resultMap").getValue();
}
String sql = element.getText();
MappedStatement select = new Select(id, parameterType, resultMap, sql, SqlCommandType.SELECT);
String key = namespace + "." + id;//将namespace和id组合成方法的完全限定名
mappedStatements.put(key, select);//以完全限定名为key,mappedStatement为value存储
}
mapper.setMappedStatement(mappedStatements);
}
System.out.println(mapper);
return mapper;
}
//辅助判断传入的值是否为空
private static boolean isBlank(Object element) {
return (element == null) || "".equals(element);
}
}
大家看了上面的代码应该发现了,我们的mapper就是一个存放XML文件信息的容器,针对于我们的操作映射对象,我选择按照Map的方式以方法的完全限定名为key,mappedStatement对象为值进行存储,这样做的好处在后面我们取mappedStatement对象的时候就能体现出来。
Generator解析Mapper对象
Mapper对象解析完成后,我们的下一步就要想想看,当我们拿到xml里面的"sql语句"和参数,我们该怎么组装,使得sql能够正常执行JDBC操作呢?
首先我们分析我们拿到的xml中的sql,是长这个样子的
我们的目的很明确,就是要将我们传进来的参数替换掉#{xxx},有的朋友一想,欸?很简单,先将这个字符串按照#差分,然后剪裁我们的字符串并将sql语句和参数再次组装成一个字符串像下面这个样子,然后执行不就行啦?
这样做逻辑上确实没有问题,但是遗留下一个很大的漏洞,也就是我们之前反复提到的sql注入的问题。
如果使用字符串拼接,是无法避免sql注入的,那么我们就要换一种思路,考虑使用Preparement预编译的方法来实现值得注入并操作数据库,那么我们得初步想法是,将上面得原始sql拼接成下面这个样子
并且,最重要得一点,是要把我们得传进来得参数,和这些属性匹配上并存放进List中(因为List是一个有序得集合),那么想法有了,就好实现了
package com.marco.util;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Generator {
/**
* @param sql 从xml中查询到的sql语句
* @param params 传入的实参
* @return
*/
public Map<String,Object> execute(String sql, Map<String,Object> params) {
Map<String,Object> paramMap = new HashMap<String, Object>();//存放prepareSql和paramList的返回值的容器
List<Object> paramList = new ArrayList<Object>();//存放prepareSql的参数的List
StringBuilder prepareSql = new StringBuilder("");//新的prepareSql
List<String> sqlStr = new ArrayList<String>();//组装新的prepareSql使用到的List
if(sql.contains("#")) {
String[] splits = sql.split("#");
for (String split : splits) {
Object object = null;
if(split.contains("{")) {
String element = trimStr(split);
object = null;
if(params.get(element) != null) {
object = params.get(element);
} else {
object = params.get("value");
}
split = split.replace("{"+element+"}", "?");
sqlStr.add(split);
} else {
sqlStr.add(0, split);
}
if(object != null) {
paramList.add(object);
}
}
for (int i = 0; i < sqlStr.size(); i++) {
prepareSql.append(sqlStr.get(i));
}
} else {
prepareSql = new StringBuilder(sql);
}
paramMap.put("prepareSql", prepareSql.toString());
paramMap.put("paramList", paramList);
return paramMap;
}
//裁剪sql得辅助方法,用于获取#{}中得属性值
private String trimStr(String str) {
int startIndex = str.indexOf("{");
int endIndex = str.indexOf("}");
return str.substring(startIndex + 1, endIndex);
}
}
重新组装之后,我们获取了类似于这样得sql字符串
insert into u_user(user_name,password,real_name) values(?,?,?)
,以及整合好得参数
['marco','123456','marco zheng']
动态代理MapperProxy解析请求
动态代理MapperProxy可谓是Mybatis的核心,我们能够调用接口的方法,并且能够顺利的执行,取到相应的值,多半要归功于它,之前我们分析过动态代理是在内存中帮我们创建了实现类和实现方法,通过调用接口的方法,间接的调用了这个内存中的实现类的实现方法,从而获取到值,因此,我们回归本质,来看看到底是怎么实现的!
package com.marco.reflect;
import java.io.Serializable;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import com.marco.domain.Session;
public class MapperProxy<T> implements InvocationHandler, Serializable {
private static final long serialVersionUID = 8398948466467961324L;
private Session session;
private final Class<T> mapperInterface;//通过getMapper()获取的接口对象的Class
public MapperProxy(Session session, Class<T> mapperInterface) {
this.session = session;
this.mapperInterface = mapperInterface;
}
@SuppressWarnings("unchecked")
public T newInstance() {//返回当前的代理对象
return (T) Proxy.newProxyInstance(mapperInterface.getClassLoader(), new Class[] {mapperInterface}, this);
}
@SuppressWarnings("unchecked")
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if(Object.class.equals(method.getDeclaringClass())){
return method.invoke(method, args);
}
//MapperMethod是方法执行单元,通过解析方法和配置最终在内存里生成相应的实现方法并调用,可以说是代理的核心
MapperMethod mapperMethod = new MapperMethod(mapperInterface, method,session.getConfig());
if(args == null) {
return mapperMethod.execute(session, null);
}
System.out.println(args.length);
if(args.length == 1) {
Class<? extends Object> clz = args[0].getClass();
if(clz == Map.class || clz == HashMap.class || clz == TreeMap.class) {
return mapperMethod.execute(session, (Map<String, Object>)args[0]);
} else if(clz == Integer.class) {
Map<String, Object> map = new HashMap<String, Object>();
map.put("value", args[0]);//主要是为了传输的参数只有一个值的时候而定义
return mapperMethod.execute(session, map);
}
}
return new RuntimeException("the method " + method + "was binded with error");
}
}
通过上面的代码我们可以发现又一个核心点MapperMethod!我们通过代理拿到了接口被调用的方法名,和session中的config(这个我们后面会讲到,实质上就是获取到了Mapper对象)
在这里我偷懒了,因为只写了两个方法insert和select,因此我对参数的解析没有那么详细,希望大家理解吧~
解析参数之后,我们就执行了MapperMethod中的execute方法,我们一探究竟,这execute方法里面的内容是什么
代理核心单元MapperMethod
MapperMethod的execute首先是对我们接口被调用方法的解析,判断他是属于什么类型,比方说是插入还是删除还是查询?这个我们在当时解析XML映射文件成Mapper的时候已经将Type封装到对应的MappedStatement中了,所以根据方法的完全限定名是可以查的到MappedStatement的,那么这个查询对应Type和方法id的单元被我们封装到内部类SqlCommand中,在上一节源码解析 Marco’s Java【Mybatis进阶(五) 漫谈Mybatis动态代理及源码解析】 时我们有特别讲到过这个内部类。
package com.marco.reflect;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import com.marco.domain.MappedStatement;
import com.marco.domain.Mapper;
import com.marco.domain.Session;
import com.marco.enums.ReturnType;
import com.marco.enums.SqlCommandType;
public class MapperMethod {
private final SqlCommand command;
private Method method;
//构造函数
public <T> MapperMethod(Class<T> mapperInterface, Method method, Mapper mapper) {
this.command = new SqlCommand(mapperInterface, method, mapper);//通过静态内部类封装type和name
this.method = method;
}
public <T> Object execute(Session session,Map<String, Object> params) {
//定义返回结果
Object result = null;
if(SqlCommandType.INSERT == command.getType()) {
boolean flag = session.insert(command.getKey(), params);
result = flag;
} else if(SqlCommandType.SELECT == command.getType()) {
String returnType = this.method.getReturnType().getSimpleName();
if(returnType.equals(ReturnType.LIST.getValue())) {
List<T> selectList = session.selectList(command.getKey(), params);
result = selectList;
} else {
Object selectOne = session.selectOne(command.getKey(), params);
result = selectOne;
}
}
if(result == null) {
throw new RuntimeException("Mapper method '" + command.getName()
+ " attempted to return null from a method with a primitive return type (" + method.getReturnType() + ")");
}
return result;
}
//SqlCommand是一个静态内部类,用于封装mappedStatement的type(执行sql的类型,如insert或者select)和name(也就是XML中mappedStatement元素的id)
public static class SqlCommand {
//xml标签的id
private final String name;
//insert update delete select的具体类型
private final SqlCommandType type;
private final String key;
public SqlCommand(Class<?> mapperInterface, Method method, Mapper mapper) {
String methodName = method.getName();//获取接口调用的方法
String interfaceName = mapperInterface.getName();//获取接口的名称
String key = interfaceName + "." + methodName;//组装完全限定名
MappedStatement mappedStatement = mapper.getMappedStatement(key);//通过完全限定名获取当前的MappedStatement
if(mappedStatement == null) {
throw new RuntimeException("the key :" + key +" doesn't exit");
}
name = mappedStatement.getId();
type = mappedStatement.getType();
this.key = key;
}
public String getName() {
return name;
}
public String getKey() {
return key;
}
public SqlCommandType getType() {
return type;
}
}
}
最终解析完所有的控制台传过来的Class和解析XML映射配置文件的内容,我们已经完全明确了下面几样信息
1)要执行什么操作
2)返回值是什么
3)传入的参数是什么
接下来就该我们的Session大显神威啦
Session
或许大家对Session的印象是对深刻的,因为从我们刚开始接触Mybatis,到使用它,第一个接触到的就是Session会话了,还记得我们刚开始没有使用动态代理时是怎么调用方法执行JDBC操作的吗?
是不是通过
sqlSession.selectOne("com.marco.UserDao.selectById","85")
执行的?
实际上我们解析了XML和被调用方法后,执行的也是这个玩意儿~
那接下来看看Session都帮我们做了什么吧~
package com.marco.domain;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.dom4j.DocumentException;
import com.marco.reflect.MapperProxy;
import com.marco.util.ConnPool;
import com.marco.util.Generator;
import com.marco.util.JDBCUtil;
import com.marco.util.ParseXML;
public class Session {
private static ParseXML parseXML;
private static Mapper mapper;
private Generator generator;
private ConnPool connPool = new ConnPool();
private Connection conn = null;
public Session() {
//通过连接池获取连接对象,连接池之前的项目中有讲到多次,这里就不放代码啦
conn = connPool.getConnection();
try {
conn.setAutoCommit(false);//设置默认提交为false
} catch (SQLException e) {
e.printStackTrace();
}
}
public <T> T getMapper(Class<T> paramClass) {
MapperProxy<T> mapperProxy = new MapperProxy<T>(this, paramClass);
return mapperProxy.newInstance();
}
static {
try {
parseXML = new ParseXML();//获取ParseXML对象
mapper = parseXML.parse();//初始化Mapper
} catch (DocumentException e) {
e.printStackTrace();
}
}
public Mapper getConfig() {
return mapper;//返回Mapper.xml文件所有配置和信息
}
//默认设值的提交方式为手动提交,当提交后,将force参数置为null,从而强制提交
public void commit() {
commit(false);
}
private void commit(boolean force) {
try {
if(!force) {
conn.commit();
}
} catch (SQLException e) {
e.printStackTrace();
}
}
//关闭sesson资源
public void close() {
try {
conn.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
//回滚事务
public void rollback() {
try {
conn.rollback();
} catch (SQLException e) {
e.printStackTrace();
}
}
/**
* @param paramString mapper中的key值,也就是方法的完全限定名
* @param params 传入的实参
* @return 返回boolean类型的值 true or false
*/
@SuppressWarnings("unchecked")
public boolean insert(String paramString, Map<String, Object> params) {
Boolean flag = false;
generator = new Generator();//获取sql代码生成器
Map<String, MappedStatement> mappedStatements = new HashMap<String, MappedStatement>();
if(mapper.getMappedStatement() != null) {
mappedStatements = mapper.getMappedStatement();//获取XML中的映射对象,如insert对象,select对象等
}
if(mappedStatements.containsKey(paramString)) {
String sql = mappedStatements.get(paramString).getSql();
Map<String, Object> execute = generator.execute(sql, params);//执行代码生成器,生成preparement SQL和处理过的传入的参数
String prepareSql = (String)execute.get("prepareSql");
List<Object> paramList = (List<Object>) execute.get("paramList");
try {
flag = JDBCUtil.add(conn, prepareSql, paramList);//使用通用添加,添加我们传入的元素
} catch (SQLException e) {
e.printStackTrace();
}
}
return flag;
}
/**
* @param <T>
* @param paramString mapper中的key值,也就是方法的完全限定名
* @param params 传入的实参
* @return 返回一个对象
*/
public <T> T selectOne(String paramString, Map<String, Object> params) {
List<T> list = selectList(paramString, params);
return list.size() > 0? list.get(0) : null;
}
/**
* @param <T>
* @param paramString mapper中的key值,也就是方法的完全限定名
* @param params 传入的实参
* @return 返回对象的一个List集合
*/
@SuppressWarnings("unchecked")
public <T> List<T> selectList(String paramString, Map<String, Object> params) {
generator = new Generator();//获取sql代码生成器
Map<String, MappedStatement> mappedStatements = new HashMap<String, MappedStatement>();
if(mapper.getMappedStatement() != null) {
mappedStatements = mapper.getMappedStatement();//获取XML中的映射对象,如insert对象,select对象等
}
if(mappedStatements.containsKey(paramString)) {
String sql = mappedStatements.get(paramString).getSql();
Map<String, Object> execute = generator.execute(sql, params);
String prepareSql = (String)execute.get("prepareSql");
List<Object> paramList = (List<Object>) execute.get("paramList");
ResultMap resultMap = mapper.getResultMap();
String type = resultMap.getType();//根据resultMap获取返回值的完全限定名
Class<?> clz;
try {
clz = Class.forName(type);//通过反射获取返回值对象的Class
List<T> objList = (List<T>) JDBCUtil.query(conn, clz, prepareSql, resultMap, paramList);//调用JDBCUtil的query方法
return objList;
} catch (ClassNotFoundException e) {
System.err.println("Class:" + type + " not found,Please reconfirm again");
e.printStackTrace();
}
}
throw new RuntimeException("No such key: " + paramString + " was found");
}
}
我们发现Session从MappedMethod中拿到了方法的完全限定名,然后去Mapper这个容器中去找,找到对应的执行sql,以及参数,然后放入到generator这个代码和参数的生成器中,最后使用我们的通用查询和添加完成了最后的收尾操作!
代码测试
那么接下来测试看看吧
public class Test {
public static void main(String[] args) {
Session session = new Session();
Map<String, Object> params = new HashMap<String, Object>();
UserMapper mapper = session.getMapper(UserMapper.class);
List<User> selectAll = mapper.selectAll();
System.out.println(selectAll);
params.put("userName", "tina");
params.put("password", "123456");
params.put("realName", "tina zhang");
mapper.addUser(params);
session.commit();
session.close();
}
}