通過mybatis插件簡化in查詢

in查詢的痛點

SQL中的in查詢語法如下:


select id, name from user where id in (1, 2, 3)

在使用mybatis做in查詢時,如果是通過xml的方式使用mybatis,則需要如下代碼:


<select id="selectByIds" resultMap="userResultMap">

select id, name from user where id in

    <foreach collection="idList" item="id" index="index" open="(" close=")" separator=",">

      #{id}

    </foreach>

</select>

如果是使用注解的方式,則更加難受:

  // The <foreach> collection attribute must name the @Param value ("idList"),
  // otherwise MyBatis cannot locate the list to iterate over.
  @Select({
      "<script>",
      " select id, name from user where id in",
      " <foreach collection='idList' item='id' open='(' separator=',' close=')'>",
      "   #{id}",
      " </foreach>",
      "</script>"
  })
  List<User> selectByIds(@Param("idList") List<Integer> idList);

可以看到,不管哪種方式都非常難用,如果工程中有較多的in查詢,會非常痛苦。

優化目標

我們可以通過自定義mybatis插件的方式來支持in查詢,插件的目標是將in查詢簡化如下:

  • xml使用方式:
<select id="selectByIds" resultMap="userResultMap">

  select id, name from user where id in (#{idList})

</select>

  • 注解的方式:
  // With ListParameterResolver installed, #{idList} is expanded into one placeholder per element.
  @Select("select id, name from user where id in (#{idList})")
  List<User> selectByIds(@Param("idList") List<Integer> idList);

實現方式

假設參數idList包含三個元素[1, 2, 3],通過mybatis插件對SQL做改寫,對于SQL:

select id, name from user where id in (#{idList})

經過改寫后:

select id, name from user where id in (#{idList0}, #{idList1}, #{idList2})

mybatis執行時的參數則由Map["idList":idList]改寫為Map["idList0":1, "idList1":2, "idList2":3],這樣mybatis就能正確執行in查詢。

代碼如下(見:https://github.com/gaohanghbut/stupidmybatis):

package cn.yxffcode.stupidmybatis.core;

import cn.yxffcode.stupidmybatis.commons.Reflections;
import com.google.common.base.Throwables;
import com.google.common.collect.Maps;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.util.*;

/**
 * Interceptor that lets a statement bind a collection or an array directly inside an
 * {@code IN (...)} clause.
 * <p/>
 * Without this plugin an {@code IN} query needs a {@code <foreach>} element in the mapper. With
 * it the mapper can simply write {@code in (#{list})}: every {@code ?} whose bound value is an
 * {@link Iterable} or an array is expanded into one {@code ?} per element, and a
 * {@link ParameterMapping} named {@code <property><index>} (e.g. {@code idList0}) is created for
 * each element.
 * <p/>
 * NOTE(review): placeholders are located by counting {@code '?'} characters in the generated
 * SQL, so a literal {@code '?'} (inside a string constant, or an operator such as PostgreSQL's
 * jsonb {@code ?}) in front of a collection placeholder misaligns the rewrite.
 * <p/>
 * NOTE(review): an empty collection is rewritten to {@code in ()}, which most databases reject.
 *
 * @author gaohang on 16/3/3.
 */
@Intercepts({@Signature(type = Executor.class, method = "query",
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(type = Executor.class, method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class,
            CacheKey.class, BoundSql.class}), @Signature(type = Executor.class, method = "update",
    args = {MappedStatement.class, Object.class})})
public class ListParameterResolver implements Interceptor {

  private static final Logger LOGGER = LoggerFactory.getLogger(ListParameterResolver.class);

  /**
   * Rewrites the collection/array placeholders of the intercepted statement, then proceeds.
   * <p/>
   * When a rewrite happens, {@code args[0]} is replaced by a copy of the mapped statement built
   * from the rewritten {@link BoundSql}, and {@code args[1]} by a map that also holds one entry
   * per expanded element. Otherwise the invocation proceeds unchanged.
   */
  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    final Object[] args = invocation.getArgs();
    MappedStatement ms = (MappedStatement) args[0];
    Object parameter = args[1];
    // The 6-argument query(...) overload already passes its BoundSql as the last argument.
    BoundSql boundSql = args[args.length - 1] instanceof BoundSql ?
        (BoundSql) args[args.length - 1] :
        ms.getBoundSql(parameter);
    // Null-check before wrapping: Collections.unmodifiableList(null) would throw an NPE.
    List<ParameterMapping> originalMappings = boundSql.getParameterMappings();
    if (originalMappings == null || originalMappings.isEmpty() || parameter == null) {
      return invocation.proceed();
    }
    List<ParameterMapping> parameterMappings = Collections.unmodifiableList(originalMappings);
    MetaObject mo = SystemMetaObject.forObject(parameter);

    // Record where every collection-valued placeholder is before rewriting anything.
    List<ListParamWrapper> paramPositions = buildListParameterPositions(
        ms.getConfiguration(), boundSql, parameterMappings, mo);
    if (paramPositions.isEmpty()) {
      return invocation.proceed();
    }

    Map<String, Object> paramMap = normalizeParameters(parameter, boundSql, parameterMappings, mo);

    rebuildBoundSql(boundSql, parameterMappings, paramPositions, paramMap);

    LOGGER.debug("rebuilded sql:{}", boundSql.getSql());

    args[1] = paramMap;
    args[0] = MappedStatementUtils.copyMappedStatement(ms, boundSql);
    return invocation.proceed();
  }

  /** Wraps only {@link Executor} instances; every other target is returned unchanged. */
  @Override
  public Object plugin(Object target) {
    if (target instanceof Executor) {
      return Plugin.wrap(target, this);
    }
    return target;
  }

  /** This interceptor has no configurable properties. */
  @Override
  public void setProperties(Properties properties) {
  }

  /**
   * Replaces every recorded collection placeholder with {@code ?,?,...} (one per element),
   * registers each element value in {@code paramMap}, and installs the new SQL and parameter
   * mappings into {@code boundSql}.
   *
   * @param boundSql          the bound SQL to rewrite in place
   * @param parameterMappings the original mappings, index-aligned with the SQL's {@code ?}s
   * @param paramPositions    collection placeholders in increasing mapping/SQL order
   * @param paramMap          parameter map that receives the generated element properties
   */
  private void rebuildBoundSql(BoundSql boundSql, List<ParameterMapping> parameterMappings,
                               List<ListParamWrapper> paramPositions,
                               Map<String, Object> paramMap) {
    String sql = boundSql.getSql();
    StringBuilder newSql = new StringBuilder(sql.length() + 16);

    // The ParameterMapping list is shared with the SqlSource, so build a fresh list rather than
    // modifying it in place.
    List<ParameterMapping> newParameterMappings = new ArrayList<>(parameterMappings.size());

    int nextSqlIndex = 0;      // first SQL character not yet copied
    int nextMappingIndex = 0;  // first original mapping not yet copied
    for (ListParamWrapper paramPosition : paramPositions) {
      // Keep the untouched mappings that precede this collection placeholder.
      newParameterMappings.addAll(
          parameterMappings.subList(nextMappingIndex, paramPosition.mappingPosition));
      // The collection's own mapping is dropped; it is replaced by one mapping per element.
      nextMappingIndex = paramPosition.mappingPosition + 1;

      // Copy the SQL up to (excluding) the original '?', which is dropped as well.
      newSql.append(sql, nextSqlIndex, paramPosition.sqlPosition);
      nextSqlIndex = paramPosition.sqlPosition + 1;

      // Emit "?,?,?": separators go only BETWEEN elements, so an empty collection never
      // removes a ',' that belongs to the original SQL.
      String prefix = expandedPropertyPrefix(paramPosition.paramName);
      int count = 0;
      for (Object element : paramPosition.params) {
        if (count > 0) {
          newSql.append(',');
        }
        newSql.append('?');
        String newProperty = prefix + count;
        paramMap.put(newProperty, element);
        newParameterMappings.add(
            copyParameterMappingForNewProperty(paramPosition, element, newProperty));
        count++;
      }
    }

    newSql.append(sql, nextSqlIndex, sql.length());
    newParameterMappings.addAll(
        parameterMappings.subList(nextMappingIndex, parameterMappings.size()));

    Reflections.setField(boundSql, "parameterMappings", newParameterMappings);
    Reflections.setField(boundSql, "sql", newSql.toString());
  }

  /**
   * Returns a mutable map holding the statement's parameter values; the expanded element values
   * are added to it afterwards.
   * <p/>
   * A {@link Map} parameter is returned as-is. NOTE(review): that map is therefore mutated with
   * the generated keys and must be mutable.
   * <p/>
   * Any other parameter object is converted to a map keyed by the ROOT name of every mapped
   * property ({@code user} for {@code user.name}, {@code items} for {@code items[0]}). MyBatis'
   * parameter handler resolves each mapping's property path against this map, so flattening
   * {@code "user.name"} into a single key would make nested values resolve to {@code null}.
   * Properties that are neither additional parameters of {@code boundSql} nor readable from the
   * parameter object are skipped instead of failing.
   */
  @SuppressWarnings("unchecked")
  private Map<String, Object> normalizeParameters(Object parameter, BoundSql boundSql,
                                                  List<ParameterMapping> parameterMappings,
                                                  MetaObject mo) {
    if (parameter instanceof Map) {
      return (Map<String, Object>) parameter;
    }
    Map<String, Object> paramMap = Maps.newHashMap();
    for (ParameterMapping pm : parameterMappings) {
      String root = rootPropertyName(pm.getProperty());
      if (paramMap.containsKey(root)) {
        continue;
      }
      if (boundSql.hasAdditionalParameter(root)) {
        paramMap.put(root, boundSql.getAdditionalParameter(root));
      } else if (mo.hasGetter(root)) {
        paramMap.put(root, mo.getValue(root));
      }
    }
    return paramMap;
  }

  /**
   * Creates a copy of the collection's parameter mapping for one element, renamed to
   * {@code newProperty}. The mapping is built reflectively through its private no-arg
   * constructor.
   *
   * @param paramPosition the collection placeholder whose mapping is copied
   * @param param         the element value (may be {@code null})
   * @param newProperty   the generated property name for the element
   */
  private ParameterMapping copyParameterMappingForNewProperty(
      ListParamWrapper paramPosition, Object param, String newProperty) {
    try {
      ParameterMapping pm = paramPosition.parameterMapping;
      MetaObject pmmo = SystemMetaObject.forObject(pm);
      Constructor<ParameterMapping> constructor = ParameterMapping.class.getDeclaredConstructor();
      constructor.setAccessible(true);
      ParameterMapping npm = constructor.newInstance();
      MetaObject npmmo = SystemMetaObject.forObject(npm);
      // copy every attribute except the property name
      Configuration configuration = (Configuration) pmmo.getValue("configuration");
      npmmo.setValue("configuration", configuration);
      npmmo.setValue("property", newProperty);
      npmmo.setValue("mode", pmmo.getValue("mode"));
      npmmo.setValue("javaType", pmmo.getValue("javaType"));
      npmmo.setValue("jdbcType", pmmo.getValue("jdbcType"));
      npmmo.setValue("numericScale", pmmo.getValue("numericScale"));
      Object typeHandler = pmmo.getValue("typeHandler");
      if (typeHandler == null) {
        // A null element has no class to look a handler up by; fall back to the unknown
        // handler, which resolves per value at bind time.
        typeHandler = param == null
            ? configuration.getTypeHandlerRegistry().getUnknownTypeHandler()
            : configuration.getTypeHandlerRegistry().getTypeHandler(param.getClass());
      }
      npmmo.setValue("typeHandler", typeHandler);
      npmmo.setValue("resultMapId", pmmo.getValue("resultMapId"));
      npmmo.setValue("jdbcTypeName", pmmo.getValue("jdbcTypeName"));
      return npm;
    } catch (Exception e) {
      throw Throwables.propagate(e);
    }
  }

  /**
   * Finds every parameter mapping whose value should be expanded (see {@link #asIterable}) and
   * records the offset of its {@code ?} in the SQL. The i-th mapping is assumed to belong to the
   * i-th {@code ?}; mappings without a matching {@code ?} are ignored.
   *
   * @return positions in mapping order, and therefore in increasing SQL offset
   */
  private List<ListParamWrapper> buildListParameterPositions(Configuration configuration,
                                                             BoundSql boundSql,
                                                             List<ParameterMapping> parameterMappings,
                                                             MetaObject mo) {
    List<ListParamWrapper> positions = new ArrayList<>(2);
    int[] placeholderOffsets = null; // computed once, on first use
    for (int i = 0, n = parameterMappings.size(); i < n; i++) {
      ParameterMapping parameterMapping = parameterMappings.get(i);
      String property = parameterMapping.getProperty();
      // Same precedence as value binding: additional parameters first, then the parameter object.
      Object value;
      if (boundSql.hasAdditionalParameter(property)) {
        value = boundSql.getAdditionalParameter(property);
      } else if (mo.hasGetter(property)) {
        value = mo.getValue(property);
      } else {
        continue;
      }
      Iterable<?> elements = asIterable(configuration, value);
      if (elements == null) {
        continue;
      }
      if (placeholderOffsets == null) {
        placeholderOffsets = findPlaceholderOffsets(boundSql.getSql());
      }
      if (i < placeholderOffsets.length) {
        positions.add(new ListParamWrapper(elements, placeholderOffsets[i], property, i,
            parameterMapping));
      }
    }
    return positions;
  }

  /**
   * Returns {@code value} viewed as an {@link Iterable} when it should be expanded, otherwise
   * {@code null}. Arrays whose class has a registered type handler (e.g. {@code byte[]} bound to
   * a binary column) are left alone so they are still bound as a single value.
   */
  private static Iterable<?> asIterable(Configuration configuration, final Object value) {
    if (value == null) {
      return null;
    }
    if (value instanceof Iterable) {
      return (Iterable<?>) value;
    }
    if (!value.getClass().isArray()
        || configuration.getTypeHandlerRegistry().hasTypeHandler(value.getClass())) {
      return null;
    }
    // Array.get/getLength also handle primitive arrays (elements are boxed); a cast to
    // Object[] would throw ClassCastException for int[], long[], ...
    return new AbstractList<Object>() {
      @Override
      public Object get(int index) {
        return Array.get(value, index);
      }

      @Override
      public int size() {
        return Array.getLength(value);
      }
    };
  }

  /** Returns the offsets of all {@code '?'} characters in {@code sql}, in order. */
  private static int[] findPlaceholderOffsets(String sql) {
    int count = 0;
    for (int k = 0; k < sql.length(); k++) {
      if (sql.charAt(k) == '?') {
        count++;
      }
    }
    int[] offsets = new int[count];
    for (int k = 0, next = 0; next < count; k++) {
      if (sql.charAt(k) == '?') {
        offsets[next++] = k;
      }
    }
    return offsets;
  }

  /**
   * Returns the first segment of a property path: {@code user} for {@code user.name},
   * {@code items} for {@code items[0].id}, the property itself when it has no path.
   */
  private static String rootPropertyName(String property) {
    int end = property.length();
    int dot = property.indexOf('.');
    if (dot >= 0) {
      end = dot;
    }
    int bracket = property.indexOf('[');
    if (bracket >= 0 && bracket < end) {
      end = bracket;
    }
    return property.substring(0, end);
  }

  /**
   * Returns the prefix for generated element properties. Path characters are replaced so that
   * e.g. {@code query.ids} yields {@code query_ids0, query_ids1, ...}, which are looked up as
   * plain map keys instead of being navigated into {@code query}. Simple names are unchanged.
   */
  private static String expandedPropertyPrefix(String property) {
    return property.replace('.', '_').replace('[', '_').replace(']', '_');
  }

  /** One collection-valued placeholder found in the SQL. */
  private static final class ListParamWrapper {
    /** The elements to expand. */
    private final Iterable<?> params;
    /** Offset of the placeholder's {@code '?'} in the original SQL. */
    private final int sqlPosition;
    /** Property name of the original mapping. */
    private final String paramName;
    /** Index of the original mapping in the parameter-mapping list. */
    private final int mappingPosition;
    /** The original mapping, used as a template for the per-element mappings. */
    private final ParameterMapping parameterMapping;

    public ListParamWrapper(Iterable<?> params, int sqlPosition, String paramName,
                            int mappingPosition, ParameterMapping parameterMapping) {
      this.params = params;
      this.sqlPosition = sqlPosition;
      this.paramName = paramName;
      this.mappingPosition = mappingPosition;
      this.parameterMapping = parameterMapping;
    }
  }
}

其中使用到的Reflections類:

package cn.yxffcode.stupidmybatis.commons;

import com.google.common.base.Throwables;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;

/**
 * Small reflection helpers for reading/writing (possibly private or inherited) fields and for
 * invoking annotation attribute methods by name.
 *
 * @author gaohang on 15/12/4.
 */
public final class Reflections {
  private Reflections() {
  }

  /** Finds a field with the given name and any type; see {@link #findField(Class, String, Class)}. */
  private static Field findField(Class<?> clazz, String name) {
    return findField(clazz, name, null);
  }

  /**
   * Searches {@code clazz} and its superclasses (stopping before {@link Object}) for a declared
   * field matching the given name and type.
   *
   * @param clazz the class to start searching from
   * @param name  the field name, or {@code null} to match any name
   * @param type  the exact field type, or {@code null} to match any type
   * @return the first matching field, or {@code null} if none is found
   */
  public static Field findField(Class<?> clazz, String name, Class<?> type) {
    for (Class<?> current = clazz; current != null && !Object.class.equals(current);
         current = current.getSuperclass()) {
      for (Field field : current.getDeclaredFields()) {
        boolean nameMatches = name == null || name.equals(field.getName());
        boolean typeMatches = type == null || type.equals(field.getType());
        if (nameMatches && typeMatches) {
          return field;
        }
      }
    }
    return null;
  }

  /**
   * Reads the named field of {@code target}, searching superclasses too.
   *
   * @throws IllegalStateException if no such field exists or it cannot be read
   */
  public static Object getField(String fieldName, Object target) {
    Field field = requireField(target.getClass(), fieldName);
    field.setAccessible(true);
    try {
      return field.get(target);
    } catch (IllegalAccessException ex) {
      throw reflectionFailure(ex);
    }
  }

  /**
   * Writes the named field of {@code target}, searching superclasses too (consistent with
   * {@link #getField(String, Object)}).
   *
   * @throws IllegalStateException if no such field exists or it cannot be written
   */
  public static void setField(Object target, String fieldName, Object value) {
    Field field = requireField(target.getClass(), fieldName);
    try {
      field.setAccessible(true);
      field.set(target, value);
    } catch (Exception ex) {
      throw reflectionFailure(ex);
    }
  }

  /**
   * Invokes the no-arg attribute method {@code methodName} on {@code annotation}.
   * Unchecked exceptions are rethrown as-is; checked ones are wrapped in a
   * {@link RuntimeException}.
   */
  @SuppressWarnings("unchecked")
  public static <T> T call(Annotation annotation, String methodName) {
    try {
      Method method = annotation.annotationType().getMethod(methodName);
      return (T) method.invoke(annotation);
    } catch (RuntimeException e) {
      throw e;
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /** Like {@link #findField(Class, String)} but fails with a descriptive message when absent. */
  private static Field requireField(Class<?> clazz, String name) {
    Field field = findField(clazz, name);
    if (field == null) {
      throw new IllegalStateException(
          "No field '" + name + "' found on " + clazz.getName() + " or its superclasses");
    }
    return field;
  }

  /** Wraps a reflection failure, keeping the original exception as the cause. */
  private static IllegalStateException reflectionFailure(Exception ex) {
    return new IllegalStateException("Unexpected reflection exception - " + ex.getClass()
        .getName() + ": " + ex.getMessage(), ex);
  }
}

將插件配置到mybatis中:

<plugin interceptor="cn.yxffcode.stupidmybatis.core.ListParameterResolver"></plugin>

如果是springboot:

@Configuration
public class MybatisConfig {

  /**
   * Exposes the IN-list expanding interceptor as a bean; being an {@link Interceptor}, it is part
   * of the interceptor list injected into {@link #sqlSessionFactory}.
   */
  @Bean
  public ListParameterResolver sqlStatsInterceptor() {
    return new ListParameterResolver();
  }

  /**
   * Builds the session factory on the given data source and installs every injected
   * {@link Interceptor} bean as a MyBatis plugin.
   */
  @Bean
  public SqlSessionFactoryBean sqlSessionFactory(DataSource dataSource, List<Interceptor> interceptors) {
    SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
    factoryBean.setDataSource(dataSource);
    if (CollectionUtils.isEmpty(interceptors)) {
      return factoryBean;
    }
    factoryBean.setPlugins(interceptors.toArray(new Interceptor[0]));
    return factoryBean;
  }

  /** Thread-safe template over the session factory above. */
  @Bean
  public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
    return new SqlSessionTemplate(sqlSessionFactory);
  }

}

最后編輯于
©著作權歸作者所有,轉載或內容合作請聯系作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

  • 1. 簡介 1.1 什么是 MyBatis ? MyBatis 是支持定制化 SQL、存儲過程以及高級映射的優秀的...
    笨鳥慢飛閱讀 6,227評論 0 4
  • Mybatis相關 1.Mybatis是什么? 2.為什么選擇Mybatis? 3、#{}和${}的區別是什么? ...
    夢殤_fccd閱讀 1,059評論 0 5
  • Mybatis插件介紹 為了提高Mybatis的可擴展能力,Mybatis引入的插件機制,允許開發人員通過責任鏈編...
    dayspring閱讀 1,494評論 0 4
  • 1、#{}和${}的區別是什么?注:這道題是面試官面試我同事的。 答:${}是Properties文件中的變量占位...
    小沙鷹168閱讀 2,272評論 2 64
  • 今天是放假第一天,感恩假期,早上可以睡一個懶覺,偶爾放松一下,太幸福了。感恩一夜好的睡眠,讓身體得到休息。感恩父母...
    那朵花蕾閱讀 201評論 0 1

友情鏈接更多精彩內容