[TOC]

Sharding-JDBC architecture

Architecture diagram
[Figure: Sharding-JDBC architecture diagram]

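Core concepts

LogicTable
The logical table of a data shard: the collective name for a group of horizontally split tables (or databases) of the same kind. Example: order data is split into 10 tables by the last digit of the primary key, t_order_0 through t_order_9; their logical table name is t_order.

ActualTable
A physical table that actually exists in a sharded database, i.e. t_order_0 through t_order_9 in the previous example.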

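DataNode
The smallest unit of data sharding, made up of a data source name and a table name, e.g. ds_1.t_order_0. By default the configuration assumes that every sharded database has the same table structure, so it is enough to configure the mapping between the logical table and the actual tables. If the table structures differ between databases, use the ds.actual_table form instead.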
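To make the relationship between logical table, actual tables and data nodes concrete, here is a minimal sketch. DataNodeSketch and its methods are hypothetical names invented for this illustration and are not part of Sharding-JDBC; it assumes the last-digit modulo rule from the t_order example.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration; not a Sharding-JDBC class.
final class DataNodeSketch {

    final String dataSourceName;
    final String tableName;

    DataNodeSketch(String dataSourceName, String tableName) {
        this.dataSourceName = dataSourceName;
        this.tableName = tableName;
    }

    // Splits "ds_1.t_order_0" into its data source and table parts.
    static DataNodeSketch parse(String text) {
        int dot = text.indexOf('.');
        if (dot <= 0 || dot == text.length() - 1) {
            throw new IllegalArgumentException("Expected <dataSource>.<table>: " + text);
        }
        return new DataNodeSketch(text.substring(0, dot), text.substring(dot + 1));
    }

    // Lists the actual tables behind a logical table: t_order -> t_order_0 .. t_order_{n-1}.
    static List<String> actualTables(String logicTable, int shardCount) {
        List<String> result = new ArrayList<>();
        for (int i = 0; i < shardCount; i++) {
            result.add(logicTable + "_" + i);
        }
        return result;
    }

    // Picks the actual table for one id (assumed rule: id modulo shard count).
    static String actualTableFor(String logicTable, long id, int shardCount) {
        return logicTable + "_" + Math.floorMod(id, (long) shardCount);
    }
}
```

With ten shards, an order whose id ends in 7 lands in t_order_7, and the text ds_1.t_order_0 parses into data source ds_1 and table t_order_0.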

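BindingTable
A main table and its child tables whose sharding rules are identical in every scenario. Example: the order table and the order item table are both sharded by order ID, so the two tables are binding tables of each other. Multi-table joins between binding tables do not produce a Cartesian product of routes, which makes join queries much more efficient. For example, given the SQL:
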
SELECT i.* FROM t_order o JOIN t_order_item i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);

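Without a binding table relationship, assume the sharding key order_id routes the value 10 to shard 0 and the value 11 to shard 1. The routed SQL then consists of 4 statements, which form a Cartesian product:
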
SELECT i.* FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);

SELECT i.* FROM t_order_0 o JOIN t_order_item_1 i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);

SELECT i.* FROM t_order_1 o JOIN t_order_item_0 i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);

SELECT i.* FROM t_order_1 o JOIN t_order_item_1 i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);

With the binding table relationship configured, the routed SQL consists of 2 statements:

SELECT i.* FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);

SELECT i.* FROM t_order_1 o JOIN t_order_item_1 i ON o.order_id=i.order_id WHERE o.order_id in (10, 11);
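
Because t_order is the leftmost table in the FROM clause, Sharding-Sphere treats it as the main table of the whole binding group. All routing is computed with the main table's strategy only, so the shards of t_order_item are computed from t_order's conditions. The sharding keys of binding tables must therefore be exactly the same.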
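
The difference between the two cases can be reproduced with a few lines of plain Java. This is only a sketch of the counting argument; BindingRouteSketch is a hypothetical class, not library code, and it hard-codes the example SQL from above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of route counts with and without binding tables.
final class BindingRouteSketch {

    private static final String TEMPLATE =
            "SELECT i.* FROM %s o JOIN %s i ON o.order_id=i.order_id WHERE o.order_id in (10, 11)";

    // Without binding: every routed t_order shard is paired with every routed t_order_item shard.
    static List<String> cartesianRoutes(List<Integer> shards) {
        List<String> result = new ArrayList<>();
        for (int orderShard : shards) {
            for (int itemShard : shards) {
                result.add(String.format(TEMPLATE, "t_order_" + orderShard, "t_order_item_" + itemShard));
            }
        }
        return result;
    }

    // With binding: the child table simply follows the shard chosen for the main table.
    static List<String> bindingRoutes(List<Integer> shards) {
        List<String> result = new ArrayList<>();
        for (int shard : shards) {
            result.add(String.format(TEMPLATE, "t_order_" + shard, "t_order_item_" + shard));
        }
        return result;
    }
}
```

For the two routed shards 0 and 1, cartesianRoutes yields four statements and bindingRoutes yields two, matching the listings above.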

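ShardingColumn
The sharding column: the key column used to split a database (or table) horizontally. Example: if the order table is sharded by the last digit of the order ID modulo the shard count, the order ID is the sharding column. A SQL statement that contains no sharding column is routed to every shard, which performs poorly. Sharding-JDBC supports multiple sharding columns.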

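ShardingAlgorithm
The sharding algorithm. Sharding-JDBC shards data through sharding algorithms and supports sharding on =, BETWEEN and IN. Sharding algorithms currently have to be implemented by the application developers themselves, which leaves a great deal of flexibility. In the future Sharding-JDBC will also ship common algorithms such as range, hash and tag.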
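
As a rough idea of what such a developer-written algorithm has to decide for =, IN and BETWEEN, here is a standalone sketch. It deliberately does not implement the library's own interfaces; ModuloShardingSketch and its method names are assumptions made for illustration, and it assumes target names that end in _<index>.

```java
import java.util.Collection;
import java.util.LinkedHashSet;

// Hypothetical modulo sharding logic; not the Sharding-JDBC algorithm interface.
final class ModuloShardingSketch {

    // '=' condition: exactly one target, the one whose suffix equals value % targetCount.
    static String shardEquals(Collection<String> targets, long value) {
        String suffix = "_" + Math.floorMod(value, (long) targets.size());
        for (String each : targets) {
            if (each.endsWith(suffix)) {
                return each;
            }
        }
        throw new IllegalStateException("No target for sharding value " + value);
    }

    // IN condition: the union of the targets of each listed value.
    static Collection<String> shardIn(Collection<String> targets, Collection<Long> values) {
        Collection<String> result = new LinkedHashSet<>();
        for (long each : values) {
            result.add(shardEquals(targets, each));
        }
        return result;
    }

    // BETWEEN condition: walk the range and stop as soon as every target has been hit.
    static Collection<String> shardBetween(Collection<String> targets, long lower, long upper) {
        Collection<String> result = new LinkedHashSet<>();
        for (long value = lower; value <= upper && result.size() < targets.size(); value++) {
            result.add(shardEquals(targets, value));
        }
        return result;
    }
}
```

With the targets t_order_0 and t_order_1, the value 10 maps to t_order_0, and IN (10, 11) maps to both tables.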

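SQL Hint
For scenarios where the sharding value is determined by some external condition rather than by the SQL, SQL Hint can inject the sharding value flexibly. Example: an internal system shards its databases by the employee login ID, but no such column exists in the database. SQL Hint can be used in two ways: through ThreadLocal and through SQL comments (not yet implemented).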
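
The ThreadLocal approach can be pictured with the following sketch. HintSketch is a hypothetical stand-in written for this article's explanation; it is not the library's hint API and only shows the mechanism: the caller stores a sharding value for the current thread, the router reads it, and the value is cleared afterwards.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical ThreadLocal-based hint holder; not the Sharding-JDBC hint API.
final class HintSketch implements AutoCloseable {

    private static final ThreadLocal<Map<String, Object>> HOLDER = new ThreadLocal<>();

    private HintSketch() {
        HOLDER.set(new HashMap<String, Object>());
    }

    // Starts a hint scope for the current thread.
    static HintSketch open() {
        return new HintSketch();
    }

    // Called by business code: the value comes from outside the SQL, e.g. a login ID.
    void setDatabaseShardingValue(String logicTable, Object value) {
        HOLDER.get().put(logicTable, value);
    }

    // Called by the router: is a hint active on this thread?
    static boolean isUseShardingHint() {
        return HOLDER.get() != null;
    }

    // Called by the router: the hinted value for a logical table, or null if none was set.
    static Object getDatabaseShardingValue(String logicTable) {
        Map<String, Object> values = HOLDER.get();
        return values == null ? null : values.get(logicTable);
    }

    // Clears the hint so it cannot leak into the next request served by this thread.
    @Override
    public void close() {
        HOLDER.remove();
    }
}
```

Using it in a try-with-resources block keeps the hint scoped to one query, which matters when threads are pooled.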

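Config Map
ConfigMap holds metadata for sharding or read/write-splitting data sources. The shardingConfig and masterSlaveConfig data in the ConfigMap can be obtained by calling ConfigMapContext.getInstance(). Example: machines with different weights may receive different amounts of traffic, and the weight metadata can be configured through ConfigMap.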

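LogicIndex
The logical index name of a data shard: the collective name for the indexes of horizontally split tables in DDL statements. Example: order data is split into 10 tables by the last digit of the primary key, t_order_0 through t_order_9, whose logical table name is t_order. For the statement DROP INDEX t_order_index, the logical index t_order_index must be configured in the TableRule.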

Configuration

Configuring sharding-jdbc

@Bean(name = "shardingDataSource")
DataSource getShardingDataSource() throws SQLException {
    Map<String, DataSource> dataSourceMap = new HashMap<>();
    // Put dataSource under the management of shardingDataSource
    dataSourceMap.put("souche_bilocation", dataSource);
    ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();
    // Define the table sharding rule
    TableRuleConfiguration tableRuleConfig = new TableRuleConfiguration();
    // Set the logical table (the table name written in SQL)
    tableRuleConfig.setLogicTable("customer_track");
    // Set the actual tables (where the data finally lands; here customer_track_0 and customer_track_1)
    tableRuleConfig.setActualDataNodes("souche_bilocation.customer_track_${0..1}");
    // Set the table sharding strategy: the sharding column is user_id,
    // and the algorithm is a custom hashCode-modulo implementation
    tableRuleConfig.setTableShardingStrategyConfig(
            new StandardShardingStrategyConfiguration("user_id", StringModeShardingAlgorithm.class.getName()));
    shardingRuleConfig.getTableRuleConfigs().add(tableRuleConfig);

    return ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig, new ConcurrentHashMap<String, Object>(), new Properties());
}
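
The string passed to setActualDataNodes() uses a ${0..1} range to name several actual tables at once. The sketch below shows how such a range can be expanded into concrete data nodes. It is a simplified stand-in that handles only the numeric ${a..b} form used in this article; InlineRangeSketch is a hypothetical name, and the library's own expression handling may support more than this.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical expander for the numeric ${a..b} ranges seen in the configuration.
final class InlineRangeSketch {

    private static final Pattern RANGE = Pattern.compile("\\$\\{(\\d+)\\.\\.(\\d+)\\}");

    // Expands the first range, then recurses so that several ranges multiply out.
    static List<String> expand(String expression) {
        List<String> result = new ArrayList<>();
        Matcher matcher = RANGE.matcher(expression);
        if (!matcher.find()) {
            result.add(expression);
            return result;
        }
        int from = Integer.parseInt(matcher.group(1));
        int to = Integer.parseInt(matcher.group(2));
        String head = expression.substring(0, matcher.start());
        String tail = expression.substring(matcher.end());
        for (int i = from; i <= to; i++) {
            result.addAll(expand(head + i + tail));
        }
        return result;
    }
}
```

Expanding souche_bilocation.customer_track_${0..1} gives the two data nodes souche_bilocation.customer_track_0 and souche_bilocation.customer_track_1.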

Execution

Prerequisites

This article is based on the following assumptions:

  1. The sharding-jdbc version is 2.0.3;

    <dependency>
    <groupId>io.shardingjdbc</groupId>
    <artifactId>sharding-jdbc-core</artifactId>
    <version>2.0.3</version>
    </dependency>
  2. The flow described is the one executed through preparedStatement.executeQuery();

    Statement.executeQuery() relies on StatementRoutingEngine;

    preparedStatement.executeQuery() relies on PreparedStatementRoutingEngine;

    Both StatementRoutingEngine and PreparedStatementRoutingEngine delegate parsing and routing to ParsingSQLRouter internally, so the overall execution flow is the same.

SQL parsing

Adding columns

The appendDerivedColumns() and appendDerivedOrderBy() methods add the columns derived from avg functions and the ordering columns, respectively.

public abstract class AbstractSelectParser implements SQLParser {

@Override
public final SelectStatement parse() {
SelectStatement result = parseInternal();
if (result.containsSubQuery()) {
result = result.mergeSubQueryStatement();
}
// TODO move to rewrite
appendDerivedColumns(result);
appendDerivedOrderBy(result);
return result;
}
}

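Derived columns for avg and group by (appendDerivedColumns)

When a SQL statement contains avg(), the data comes from several databases and tables, and averaging the per-shard averages gives the wrong result. The correct value is sum(field)/count(field), so the original select avg(field) from ... has to be rewritten to also select sum(field) and count(field).

This stage happens in MySQLSelectParser.parse() (in practice AbstractSelectParser.parse()).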
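
A small worked example shows why the derived sum and count columns are needed. AvgMergeSketch is a hypothetical helper, not library code; each inner array stands for one shard's {sum, count} result row.

```java
// Hypothetical illustration of merging avg() across shards.
final class AvgMergeSketch {

    // Correct merge: add up the per-shard sums and counts, then divide once.
    static double mergedAvg(long[][] sumAndCountPerShard) {
        long sum = 0;
        long count = 0;
        for (long[] each : sumAndCountPerShard) {
            sum += each[0];
            count += each[1];
        }
        return count == 0 ? 0 : (double) sum / count;
    }

    // Naive merge: average the per-shard averages, which ignores how many rows each shard holds.
    static double avgOfAvgs(long[][] sumAndCountPerShard) {
        double total = 0;
        for (long[] each : sumAndCountPerShard) {
            total += (double) each[0] / each[1];
        }
        return total / sumAndCountPerShard.length;
    }
}
```

Take shard A with the values 10, 20, 30 (sum 60, count 3) and shard B with the single value 100 (sum 100, count 1). The true average is 160/4 = 40, while averaging the two shard averages (20 and 100) gives 60.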

private void appendDerivedColumns(final SelectStatement selectStatement) {
    ItemsToken itemsToken = new ItemsToken(selectStatement.getSelectListLastPosition());
    // For avg(field), add sum(field) and count(field)
    appendAvgDerivedColumns(itemsToken, selectStatement);
    // Add order by / group by columns that are missing from the select list, so that in-memory sorting works
    appendDerivedOrderColumns(itemsToken, selectStatement.getOrderByItems(), ORDER_BY_DERIVED_ALIAS, selectStatement);
    appendDerivedOrderColumns(itemsToken, selectStatement.getGroupByItems(), GROUP_BY_DERIVED_ALIAS, selectStatement);
    // Register the new columns as a SQLToken; SQLTokens drive the final SQL produced by SQLRewriteEngine.rewrite()
    if (!itemsToken.getItems().isEmpty()) {
        selectStatement.getSqlTokens().add(itemsToken);
    }
}

private void appendAvgDerivedColumns(final ItemsToken itemsToken, final SelectStatement selectStatement) {
    // Offset used in the aliases to avoid alias collisions
    int derivedColumnOffset = 0;
    for (SelectItem each : selectStatement.getItems()) {
        // Only avg functions are of interest
        if (!(each instanceof AggregationSelectItem) || AggregationType.AVG != ((AggregationSelectItem) each).getType()) {
            continue;
        }
        AggregationSelectItem avgItem = (AggregationSelectItem) each;
        // Alias and item for count
        String countAlias = String.format(DERIVED_COUNT_ALIAS, derivedColumnOffset);
        AggregationSelectItem countItem = new AggregationSelectItem(AggregationType.COUNT, avgItem.getInnerExpression(), Optional.of(countAlias));
        // Alias and item for sum
        String sumAlias = String.format(DERIVED_SUM_ALIAS, derivedColumnOffset);
        AggregationSelectItem sumItem = new AggregationSelectItem(AggregationType.SUM, avgItem.getInnerExpression(), Optional.of(sumAlias));
        // Attach count and sum to the avg item's derivedAggregationSelectItems
        avgItem.getDerivedAggregationSelectItems().add(countItem);
        avgItem.getDerivedAggregationSelectItems().add(sumItem);
        // TODO replace avg to constant, avoid calculate useless avg
        // Add them to itemsToken for the later SQL rewrite
        itemsToken.getItems().add(countItem.getExpression() + " AS " + countAlias + " ");
        itemsToken.getItems().add(sumItem.getExpression() + " AS " + sumAlias + " ");
        derivedColumnOffset++;
    }
}

// Add order by / group by columns that are missing from the select list, so that in-memory sorting works
private void appendDerivedOrderColumns(final ItemsToken itemsToken, final List<OrderItem> orderItems, final String aliasPattern, final SelectStatement selectStatement) {
    int derivedColumnOffset = 0;
    for (OrderItem each : orderItems) {
        if (!isContainsItem(each, selectStatement)) {
            String alias = String.format(aliasPattern, derivedColumnOffset++);
            each.setAlias(Optional.of(alias));
            itemsToken.getItems().add(each.getQualifiedName().get() + " AS " + alias + " ");
        }
    }
}

Adding an ordering when there is group by but no order by (appendDerivedOrderBy)

If the statement has group by but no order by, the group by columns are added as the ordering columns.

private void appendDerivedOrderBy(final SelectStatement selectStatement) {
if (!selectStatement.getGroupByItems().isEmpty() && selectStatement.getOrderByItems().isEmpty()) {
selectStatement.getOrderByItems().addAll(selectStatement.getGroupByItems());
selectStatement.getSqlTokens().add(new OrderByToken(selectStatement.getGroupByLastPosition()));
}
}

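Routing

Routing here means the process of finding the actual databases and physical table names from a logical table name.

Routing happens in ParsingSQLRouter.route() -> RoutingEngine.route().

ParsingSQLRouter.route() picks a different RoutingEngine depending on whether the SQL is DDL, has no table, has a single table, or has multiple tables.
[Figure: routing flow]
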
private RoutingResult route(final List<Object> parameters, final SQLStatement sqlStatement) {
Collection<String> tableNames = sqlStatement.getTables().getTableNames();
RoutingEngine routingEngine;
if (sqlStatement instanceof DDLStatement) {
routingEngine = new DDLRoutingEngine(shardingRule, parameters, (DDLStatement) sqlStatement);
} else if (tableNames.isEmpty()) {
routingEngine = new DatabaseAllRoutingEngine(shardingRule.getDataSourceMap());
} else if (1 == tableNames.size() || shardingRule.isAllBindingTables(tableNames) || shardingRule.isAllInDefaultDataSource(tableNames)) {
routingEngine = new SimpleRoutingEngine(shardingRule, parameters, tableNames.iterator().next(), sqlStatement);
} else {
// TODO config for cartesian set
routingEngine = new ComplexRoutingEngine(shardingRule, parameters, tableNames, sqlStatement);
}
return routingEngine.route();
}

The routing result (RoutingResult) is then used in ParsingSQLRouter.route() to obtain the TableUnits.

@Override
public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
    // ... unrelated code omitted

    // route() runs the routing process described above and returns a RoutingResult,
    // which may be a CartesianRoutingResult
    RoutingResult routingResult = route(parameters, sqlStatement);
    // CartesianRoutingResult is handled separately here
    if (routingResult instanceof CartesianRoutingResult) {
        for (CartesianDataSource cartesianDataSource : ((CartesianRoutingResult) routingResult).getRoutingDataSources()) {
            // cartesianTableReference holds all actual table names within one database
            for (CartesianTableReference cartesianTableReference : cartesianDataSource.getRoutingTableReferences()) {
                result.getExecutionUnits().add(new SQLExecutionUnit(cartesianDataSource.getDataSource(), rewriteEngine.generateSQL(cartesianTableReference, sqlBuilder)));
            }
        }
    } else { // generic handling
        for (TableUnit each : routingResult.getTableUnits().getTableUnits()) {
            result.getExecutionUnits().add(new SQLExecutionUnit(each.getDataSourceName(), rewriteEngine.generateSQL(each, sqlBuilder)));
        }
    }
    if (showSQL) {
        SQLLogger.logSQL(logicSQL, sqlStatement, result.getExecutionUnits(), parameters);
    }
    return result;
}

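Single-table routing (SimpleRoutingEngine)

SimpleRoutingEngine is the core of all routing; it routes according to the configured TableRule.

Routing steps:

  1. Get the database and table sharding values; for forced (hint) routing they come from HintManagerHolder;
  2. Route databases with the ShardingStrategy configured through StandardShardingStrategyConfiguration;
  3. Route tables with the ShardingStrategy configured through StandardShardingStrategyConfiguration;
  4. Wrap the routing result.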
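
The four steps above can be sketched in miniature as follows. SimpleRouteSketch is a hypothetical reduction, not the library class: it assumes two data sources, two tables and a plain modulo rule, and it represents a missing sharding condition with null.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical miniature of the database-then-table routing steps.
final class SimpleRouteSketch {

    static final List<String> DATA_SOURCES = Arrays.asList("ds_0", "ds_1");
    static final List<String> TABLES = Arrays.asList("t_order_0", "t_order_1");

    // Steps 2 and 3: no sharding value means every target stays a candidate.
    static List<String> shard(List<String> targets, Long shardingValue) {
        if (shardingValue == null) {
            return targets;
        }
        long value = shardingValue;
        int index = (int) Math.floorMod(value, (long) targets.size());
        return Collections.singletonList(targets.get(index));
    }

    // Steps 1 to 4: route data sources first, then tables inside each routed data source.
    static List<String> route(Long databaseValue, Long tableValue) {
        List<String> dataNodes = new ArrayList<>();
        for (String dataSource : shard(DATA_SOURCES, databaseValue)) {
            for (String table : shard(TABLES, tableValue)) {
                dataNodes.add(dataSource + "." + table);
            }
        }
        return dataNodes;
    }
}
```

Without any sharding value the query goes to all four data nodes; with both values present it goes to exactly one.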
@RequiredArgsConstructor
public final class SimpleRoutingEngine implements RoutingEngine {

private final ShardingRule shardingRule;

private final List<Object> parameters;

private final String logicTableName;

private final SQLStatement sqlStatement;

@Override
public RoutingResult route() {
// Look up the configured rule by logical table name
TableRule tableRule = shardingRule.getTableRule(logicTableName);
// Get the database sharding values
List<ShardingValue> databaseShardingValues = getDatabaseShardingValues(tableRule);
// Get the table sharding values
List<ShardingValue> tableShardingValues = getTableShardingValues(tableRule);
// Route data sources according to the configured ShardingStrategy
Collection<String> routedDataSources = routeDataSources(tableRule, databaseShardingValues);
Collection<DataNode> routedDataNodes = new LinkedList<>();
for (String each : routedDataSources) {
// Route tables according to the table sharding values
routedDataNodes.addAll(routeTables(tableRule, each, tableShardingValues));
}
return generateRoutingResult(routedDataNodes);
}

private List<ShardingValue> getDatabaseShardingValues(final TableRule tableRule) {
// strategy: the configured sharding strategy
ShardingStrategy strategy = shardingRule.getDatabaseShardingStrategy(tableRule);
// isUseShardingHint: check first whether forced (hint) routing is in use
return HintManagerHolder.isUseShardingHint() ? getDatabaseShardingValuesFromHint(strategy.getShardingColumns()) : getShardingValues(strategy.getShardingColumns());
}

private List<ShardingValue> getTableShardingValues(final TableRule tableRule) {
// strategy: the configured sharding strategy
ShardingStrategy strategy = shardingRule.getTableShardingStrategy(tableRule);
// isUseShardingHint: check first whether forced (hint) routing is in use
return HintManagerHolder.isUseShardingHint() ? getTableShardingValuesFromHint(strategy.getShardingColumns()) : getShardingValues(strategy.getShardingColumns());
}

private List<ShardingValue> getDatabaseShardingValuesFromHint(final Collection<String> shardingColumns) {
List<ShardingValue> result = new ArrayList<>(shardingColumns.size());
for (String each : shardingColumns) {
Optional<ShardingValue> shardingValue = HintManagerHolder.getDatabaseShardingValue(new ShardingKey(logicTableName, each));
if (shardingValue.isPresent()) {
result.add(shardingValue.get());
}
}
return result;
}

private List<ShardingValue> getTableShardingValuesFromHint(final Collection<String> shardingColumns) {
List<ShardingValue> result = new ArrayList<>(shardingColumns.size());
for (String each : shardingColumns) {
Optional<ShardingValue> shardingValue = HintManagerHolder.getTableShardingValue(new ShardingKey(logicTableName, each));
if (shardingValue.isPresent()) {
result.add(shardingValue.get());
}
}
return result;
}

private List<ShardingValue> getShardingValues(final Collection<String> shardingColumns) {
List<ShardingValue> result = new ArrayList<>(shardingColumns.size());
for (String each : shardingColumns) {
Optional<Condition> condition = sqlStatement.getConditions().find(new Column(each, logicTableName));
if (condition.isPresent()) {
result.add(condition.get().getShardingValue(parameters));
}
}
return result;
}

private Collection<String> routeDataSources(final TableRule tableRule, final List<ShardingValue> databaseShardingValues) {
Collection<String> availableTargetDatabases = tableRule.getActualDatasourceNames();
if (databaseShardingValues.isEmpty()) {
return availableTargetDatabases;
}
// doSharding makes the routing decision
Collection<String> result = shardingRule.getDatabaseShardingStrategy(tableRule).doSharding(availableTargetDatabases, databaseShardingValues);
Preconditions.checkState(!result.isEmpty(), "no database route info");
return result;
}

private Collection<DataNode> routeTables(final TableRule tableRule, final String routedDataSource, final List<ShardingValue> tableShardingValues) {
Collection<String> availableTargetTables = tableRule.getActualTableNames(routedDataSource);
// doSharding makes the routing decision
Collection<String> routedTables = tableShardingValues.isEmpty() ? availableTargetTables
: shardingRule.getTableShardingStrategy(tableRule).doSharding(availableTargetTables, tableShardingValues);
Preconditions.checkState(!routedTables.isEmpty(), "no table route info");
Collection<DataNode> result = new LinkedList<>();
for (String each : routedTables) {
result.add(new DataNode(routedDataSource, each));
}
return result;
}
// Wrap the data nodes into a RoutingResult
private RoutingResult generateRoutingResult(final Collection<DataNode> routedDataNodes) {
RoutingResult result = new RoutingResult();
for (DataNode each : routedDataNodes) {
result.getTableUnits().getTableUnits().add(new TableUnit(each.getDataSourceName(), logicTableName, each.getTableName()));
}
return result;
}
}

DDL routing (DDLRoutingEngine)

DDL routing actually delegates to SimpleRoutingEngine, so it behaves the same as single-table routing; see that section for details.

@RequiredArgsConstructor
public final class DDLRoutingEngine implements RoutingEngine {

    private final ShardingRule shardingRule;

    private final List<Object> parameters;

    private final DDLStatement ddlStatement;

    @Override
    public RoutingResult route() {
        // Delegate to SimpleRoutingEngine
        return new SimpleRoutingEngine(shardingRule, parameters, getLogicTableName(), ddlStatement).route();
    }

    private String getLogicTableName() {
        if (ddlStatement.getTables().isEmpty()) {
            return shardingRule.getLogicTableName(getIndexToken().getIndexName());
        }
        return ddlStatement.getTables().getSingleTableName();
    }

    private IndexToken getIndexToken() {
        Preconditions.checkState(1 == ddlStatement.getSqlTokens().size());
        return (IndexToken) ddlStatement.getSqlTokens().get(0);
    }
}

No-table routing (DatabaseAllRoutingEngine)

The result of no-table routing wraps an empty table, new TableUnit(each, "", ""), where each is the data source name of each database.

@RequiredArgsConstructor
public final class DatabaseAllRoutingEngine implements RoutingEngine {

private final Map<String, DataSource> dataSourceMap;

@Override
public RoutingResult route() {
RoutingResult result = new RoutingResult();
for (String each : dataSourceMap.keySet()) {
result.getTableUnits().getTableUnits().add(new TableUnit(each, "", ""));
}
return result;
}
}

Multi-table routing (ComplexRoutingEngine)

Binding tables

Configuring binding tables

Call shardingRuleConfig.setBindingTableGroups() in the ShardingDataSource configuration to bind tables.

@Bean(name = "shardingDataSource")
DataSource getShardingDataSource() throws SQLException {
Map<String, DataSource> dataSourceMap = new HashMap<>();
dataSourceMap.put("souche_bilocation_0", dataSource);
dataSourceMap.put("souche_bilocation_1", dataSource);
ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();

shardingRuleConfig.setDefaultDatabaseShardingStrategyConfig(new StandardShardingStrategyConfiguration("id", StringModeShardingAlgorithm.class.getName()));

TableRuleConfiguration tableRuleConfig = new TableRuleConfiguration();
tableRuleConfig.setLogicTable("customer_track");
tableRuleConfig.setActualDataNodes("souche_bilocation_${0..1}.customer_track_${0..1}");
tableRuleConfig.setTableShardingStrategyConfig(new StandardShardingStrategyConfiguration("id", StringModeShardingAlgorithm.class.getName()));
shardingRuleConfig.getTableRuleConfigs().add(tableRuleConfig);

TableRuleConfiguration customer_track_detail = new TableRuleConfiguration();
customer_track_detail.setLogicTable("customer_track_detail");
customer_track_detail.setActualDataNodes("souche_bilocation_${0..1}.customer_track_detail_${0..1}");
customer_track_detail.setTableShardingStrategyConfig(new StandardShardingStrategyConfiguration("track_id", StringModeShardingAlgorithm.class.getName()));
shardingRuleConfig.getTableRuleConfigs().add(customer_track_detail);
// Bind the tables
shardingRuleConfig.setBindingTableGroups(Arrays.asList("customer_track,customer_track_detail"));

return ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig,new ConcurrentHashMap<String, Object>(), new Properties());
}

The ShardingRule constructor converts bindingTableGroups into BindingTableRules.

public ShardingRule(final Map<String, DataSource> dataSourceMap, final String defaultDataSourceName, final Collection<TableRule> tableRules,
                    final Collection<String> bindingTableGroups, final ShardingStrategy defaultDatabaseShardingStrategy,
                    final ShardingStrategy defaultTableShardingStrategy, final KeyGenerator defaultKeyGenerator) {
    // ... rest omitted
    for (String group : bindingTableGroups) {
        List<TableRule> tableRulesForBinding = new LinkedList<>();
        // Split the group on commas
        for (String logicTableNameForBindingTable : StringUtil.splitWithComma(group)) {
            tableRulesForBinding.add(getTableRule(logicTableNameForBindingTable));
        }
        // The rules of the bound tables are kept in a BindingTableRule
        this.bindingTableRules.add(new BindingTableRule(tableRulesForBinding));
    }
    // ... rest omitted
}

BindingTableRule

BindingTableRule holds the collection of TableRules of the bound tables.

@RequiredArgsConstructor
@Getter
public final class BindingTableRule {

private final List<TableRule> tableRules;

/**
* Adjust contains this logic table in this rule.
*
* @param logicTableName logic table name
* @return contains this logic table or not
*/
public boolean hasLogicTable(final String logicTableName) {
for (TableRule each : tableRules) {
if (each.getLogicTable().equals(logicTableName.toLowerCase())) {
return true;
}
}
return false;
}

/**
* Deduce actual table name from other actual table name in same binding table rule.
*
* @param dataSource data source name
* @param logicTable logic table name
* @param otherActualTable other actual table name in same binding table rule
* @return actual table name
*/
public String getBindingActualTable(final String dataSource, final String logicTable, final String otherActualTable) {
int index = -1;
// Find the position (index) of the other actual table (child or parent table)
for (TableRule each : tableRules) {
index = each.findActualTableIndex(dataSource, otherActualTable);
if (-1 != index) {
break;
}
}
Preconditions.checkState(-1 != index, String.format("Actual table [%s].[%s] is not in table config", dataSource, otherActualTable));
// Use that index to pick the actual table of logicTable at the same position
for (TableRule each : tableRules) {
if (each.getLogicTable().equals(logicTable.toLowerCase())) {
return each.getActualDataNodes().get(index).getTableName().toLowerCase();
}
}
throw new IllegalStateException(String.format("Cannot find binding actual table, data source: %s, logic table: %s, other actual table: %s", dataSource, logicTable, otherActualTable));
}
// All logical tables
Collection<String> getAllLogicTables() {
return Lists.transform(tableRules, new Function<TableRule, String>() {

@Override
public String apply(final TableRule input) {
return input.getLogicTable().toLowerCase();
}
});
}
}
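
The index-based deduction in getBindingActualTable() can be tried out with a reduced model. BindingActualTableSketch is a hypothetical simplification written for illustration: it ignores data sources and keeps only an ordered list of actual tables per logical table.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical reduction of the binding-table lookup; data sources are left out.
final class BindingActualTableSketch {

    // Logical table name -> actual table names, in configuration order.
    private final Map<String, List<String>> actualTables = new LinkedHashMap<>();

    void addTable(String logicTable, String... tables) {
        actualTables.put(logicTable, Arrays.asList(tables));
    }

    // Finds the position of otherActualTable in any bound table,
    // then returns the actual table of logicTable at the same position.
    String bindingActualTable(String logicTable, String otherActualTable) {
        int index = -1;
        for (List<String> each : actualTables.values()) {
            index = each.indexOf(otherActualTable);
            if (index != -1) {
                break;
            }
        }
        if (index == -1) {
            throw new IllegalStateException("Actual table is not configured: " + otherActualTable);
        }
        List<String> candidates = actualTables.get(logicTable);
        if (candidates == null) {
            throw new IllegalStateException("Logical table is not in this binding group: " + logicTable);
        }
        return candidates.get(index);
    }
}
```

The lookup works only because bound tables list their actual tables in the same order, which is another way of saying that their sharding rules must match.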

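Routing analysis

Routing logic:

  1. Look up the TableRule in shardingRule; if the current table belongs to a binding group that has already been recorded, skip it;
  2. If it has no binding tables, or its binding group has not been recorded yet, route it with SimpleRoutingEngine and keep the result;
  3. Look up the current table's binding tables and record them;
  4. If there is only one SimpleRoutingEngine result, return it;
  5. If there are several results, route them again with CartesianRoutingEngine.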
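
The bookkeeping behind steps 1 to 3 can be sketched as below. ComplexRouteSketch is a hypothetical helper, not library code; it only answers the question of which logical tables would get their own SimpleRoutingEngine run, given the binding groups.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical illustration of skipping tables already covered by a binding group.
final class ComplexRouteSketch {

    static List<String> tablesToRoute(List<String> logicTables, List<List<String>> bindingGroups) {
        List<String> routed = new ArrayList<>();
        Set<String> covered = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        for (String each : logicTables) {
            // Steps 1 and 2: route the table only if no earlier table's binding group covers it.
            if (!covered.contains(each)) {
                routed.add(each);
            }
            // Step 3: record every table bound to the current one.
            for (List<String> group : bindingGroups) {
                if (containsIgnoreCase(group, each)) {
                    covered.addAll(group);
                }
            }
        }
        return routed;
    }

    private static boolean containsIgnoreCase(List<String> names, String name) {
        for (String each : names) {
            if (each.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }
}
```

If the returned list has one entry, its routing result is used directly (step 4); with two or more entries the results go through Cartesian routing (step 5).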
@RequiredArgsConstructor
@Slf4j
public final class ComplexRoutingEngine implements RoutingEngine {

private final ShardingRule shardingRule;

private final List<Object> parameters;

private final Collection<String> logicTables;

private final SQLStatement sqlStatement;

@Override
public RoutingResult route() {
Collection<RoutingResult> result = new ArrayList<>(logicTables.size());
Collection<String> bindingTableNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
/**
* logicTables: customer_track, customer_track_detail
* customer_track is routed with SimpleRoutingEngine.route()
* customer_track_detail is bound to customer_track, so it follows customer_track's routing result
**/
for (String each : logicTables) {
// Look up the rule of the current table
Optional<TableRule> tableRule = shardingRule.tryFindTableRule(each);
if (tableRule.isPresent()) {
if (!bindingTableNames.contains(each)) {
// The first table of a binding group is routed with SimpleRoutingEngine, i.e. with its configured rule
result.add(new SimpleRoutingEngine(shardingRule, parameters, tableRule.get().getLogicTable(), sqlStatement).route());
}
// Look up the binding table rule
Optional<BindingTableRule> bindingTableRule = shardingRule.findBindingTableRule(each);
if (bindingTableRule.isPresent()) {
bindingTableNames.addAll(Lists.transform(bindingTableRule.get().getTableRules(), new Function<TableRule, String>() {

@Override
public String apply(final TableRule input) {
return input.getLogicTable();
}
}));
}
}
}
log.trace("mixed tables sharding result: {}", result);
if (result.isEmpty()) {
throw new ShardingJdbcException("Cannot find table rule and default data source with logic tables: '%s'", logicTables);
}
// All tables are bound to each other, so they share one routing result
if (1 == result.size()) {
return result.iterator().next();
}
// Fall back to Cartesian routing
return new CartesianRoutingEngine(result).route();
}
}

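Cartesian routing engine (CartesianRoutingEngine)

The Cartesian routing engine is used only when there are several SimpleRoutingEngine results.

Routing logic:
All routing results are grouped by DataSource and stored in a CartesianRoutingResult.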
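
The two building blocks of this engine, intersecting the data sources and taking the Cartesian product of the table groups inside one data source, can be sketched without any library. CartesianSketch is a hypothetical illustration; it uses an explicit first-element flag for the intersection and plain lists instead of Guava.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical illustration of the two core operations of Cartesian routing.
final class CartesianSketch {

    // Keeps only the data sources that appear in every routing result.
    static Set<String> intersection(List<Set<String>> dataSourcesPerResult) {
        Set<String> result = new LinkedHashSet<>();
        boolean first = true;
        for (Set<String> each : dataSourcesPerResult) {
            if (first) {
                result.addAll(each);
                first = false;
            } else {
                result.retainAll(each);
            }
        }
        return result;
    }

    // Builds every combination that takes one actual table from each group.
    static List<List<String>> product(List<List<String>> groups) {
        List<List<String>> result = new ArrayList<>();
        result.add(new ArrayList<String>());
        for (List<String> group : groups) {
            List<List<String>> next = new ArrayList<>();
            for (List<String> prefix : result) {
                for (String table : group) {
                    List<String> combined = new ArrayList<>(prefix);
                    combined.add(table);
                    next.add(combined);
                }
            }
            result = next;
        }
        return result;
    }
}
```

Two unbound tables that each route to two actual tables in the same data source produce four combinations, and each combination later becomes one SQL statement.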

The route method of CartesianRoutingEngine is shown below.

@RequiredArgsConstructor
@Slf4j
public final class CartesianRoutingEngine implements RoutingEngine {
// Routing results, passed in through the constructor
private final Collection<RoutingResult> routingResults;

@Override
public CartesianRoutingResult route() {
CartesianRoutingResult result = new CartesianRoutingResult();
// key: data source name, value: set of logical table names
for (Entry<String, Set<String>> entry : getDataSourceLogicTablesMap().entrySet()) {
// Get the actual table names for the logical tables
List<Set<String>> actualTableGroups = getActualTableGroups(entry.getKey(), entry.getValue());
// Wrap them as TableUnits
List<Set<TableUnit>> tableUnitGroups = toTableUnitGroups(entry.getKey(), actualTableGroups);
// getCartesianTableReferences converts the TableUnits into CartesianTableReferences
// For the merge, see CartesianRoutingResult
result.merge(entry.getKey(), getCartesianTableReferences(Sets.cartesianProduct(tableUnitGroups)));
}
log.trace("cartesian tables sharding result: {}", result);
return result;
}
// key: data source name, value: set of logical table names
private Map<String, Set<String>> getDataSourceLogicTablesMap() {
// Get the intersection of the data sources
Collection<String> intersectionDataSources = getIntersectionDataSources();
// key: data source name, value: set of logical table names
Map<String, Set<String>> result = new HashMap<>(routingResults.size());
for (RoutingResult each : routingResults) {
for (Entry<String, Set<String>> entry : each.getTableUnits().getDataSourceLogicTablesMap(intersectionDataSources).entrySet()) {
if (result.containsKey(entry.getKey())) {
result.get(entry.getKey()).addAll(entry.getValue());
} else {
result.put(entry.getKey(), entry.getValue());
}
}
}
return result;
}
// Intersect the data sources of all routingResults
private Collection<String> getIntersectionDataSources() {
Collection<String> result = new HashSet<>();
for (RoutingResult each : routingResults) {
if (result.isEmpty()) {
result.addAll(each.getTableUnits().getDataSourceNames());
}
// retainAll computes the intersection
result.retainAll(each.getTableUnits().getDataSourceNames());
}
return result;
}

private List<Set<String>> getActualTableGroups(final String dataSource, final Set<String> logicTables) {
List<Set<String>> result = new ArrayList<>(logicTables.size());
for (RoutingResult each : routingResults) {
result.addAll(each.getTableUnits().getActualTableNameGroups(dataSource, logicTables));
}
return result;
}

private List<Set<TableUnit>> toTableUnitGroups(final String dataSource, final List<Set<String>> actualTableGroups) {
List<Set<TableUnit>> result = new ArrayList<>(actualTableGroups.size());
for (Set<String> each : actualTableGroups) {
result.add(new HashSet<>(Lists.transform(new ArrayList<>(each), new Function<String, TableUnit>() {

@Override
public TableUnit apply(final String input) {
return findTableUnit(dataSource, input);
}
})));
}
return result;
}

private TableUnit findTableUnit(final String dataSource, final String actualTable) {
for (RoutingResult each : routingResults) {
Optional<TableUnit> result = each.getTableUnits().findTableUnit(dataSource, actualTable);
if (result.isPresent()) {
return result.get();
}
}
throw new IllegalStateException(String.format("Cannot found routing table factor, data source: %s, actual table: %s", dataSource, actualTable));
}

private List<CartesianTableReference> getCartesianTableReferences(final Set<List<TableUnit>> cartesianTableUnitGroups) {
List<CartesianTableReference> result = new ArrayList<>(cartesianTableUnitGroups.size());
for (List<TableUnit> each : cartesianTableUnitGroups) {
result.add(new CartesianTableReference(each));
}
return result;
}
}

Cartesian routing result (CartesianRoutingResult)

Merge logic:

  1. Group the routingTableReferences by dataSource and store them.

@ToString
public final class CartesianRoutingResult extends RoutingResult {

@Getter
private final List<CartesianDataSource> routingDataSources = new ArrayList<>();

void merge(final String dataSource, final Collection<CartesianTableReference> routingTableReferences) {
for (CartesianTableReference each : routingTableReferences) {
merge(dataSource, each);
}
}

private void merge(final String dataSource, final CartesianTableReference routingTableReference) {
// routingDataSources already contains this dataSource: just add the table reference (routingTableReference)
for (CartesianDataSource each : routingDataSources) {
if (each.getDataSource().equalsIgnoreCase(dataSource)) {
each.getRoutingTableReferences().add(routingTableReference);
return;
}
}
// routingDataSources does not contain this dataSource yet
routingDataSources.add(new CartesianDataSource(dataSource, routingTableReference));
}

@Override
public boolean isSingleRouting() {
Collection<CartesianTableReference> cartesianTableReferences = new LinkedList<>();
for (CartesianDataSource cartesianDataSource : routingDataSources) {
for (CartesianTableReference cartesianTableReference : cartesianDataSource.getRoutingTableReferences()) {
cartesianTableReferences.add(cartesianTableReference);
}
}
return 1 == cartesianTableReferences.size();
}
}

CartesianRoutingResult extends RoutingResult but never assigns RoutingResult's TableUnits tableUnits field, so code that reads the routing result from a CartesianRoutingResult cannot share the generic handling path.

CartesianRoutingResult is handled separately in ParsingSQLRouter.route().

@Override
public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
    // ... unrelated code omitted

    // route() runs the routing process described above and returns a RoutingResult,
    // which may be a CartesianRoutingResult
    RoutingResult routingResult = route(parameters, sqlStatement);
    // CartesianRoutingResult is handled separately here
    if (routingResult instanceof CartesianRoutingResult) {
        for (CartesianDataSource cartesianDataSource : ((CartesianRoutingResult) routingResult).getRoutingDataSources()) {
            // cartesianTableReference holds all actual table names within one database
            for (CartesianTableReference cartesianTableReference : cartesianDataSource.getRoutingTableReferences()) {
                result.getExecutionUnits().add(new SQLExecutionUnit(cartesianDataSource.getDataSource(), rewriteEngine.generateSQL(cartesianTableReference, sqlBuilder)));
            }
        }
    } else { // generic handling
        for (TableUnit each : routingResult.getTableUnits().getTableUnits()) {
            result.getExecutionUnits().add(new SQLExecutionUnit(each.getDataSourceName(), rewriteEngine.generateSQL(each, sqlBuilder)));
        }
    }
    if (showSQL) {
        SQLLogger.logSQL(logicSQL, sqlStatement, result.getExecutionUnits(), parameters);
    }
    return result;
}

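SQL rewriting (SQLRewriteEngine)

[Figure: SQL rewriting flow]

SQL rewriting happens in ParsingSQLRouter.route() -> rewriteEngine.rewrite().

The rewriting lives in the SQLRewriteEngine class, which has two core methods:

  1. rewrite: rewrites the SQL statement and splits it into segments, structured for example as:

    select * from, tableToken, where id=?
  2. generateSQL: generates the rewritten SQL statement; its main job is replacing table names.

SQLRewriteEngine rewrites table names and the offset and row count of limit.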
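
The split between the two methods, build the segments once and then render them once per routed table, can be sketched as follows. RewriteSketch is a hypothetical reduction made for illustration and covers table tokens only.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical reduction of the segment-then-render idea behind SQL rewriting.
final class RewriteSketch {

    // Marks the place in the SQL where a logical table name has to be replaced.
    private static final class TableSegment {
        final String logicTable;

        TableSegment(String logicTable) {
            this.logicTable = logicTable;
        }
    }

    private final List<Object> segments = new ArrayList<>();

    // Appends SQL text that is copied through unchanged.
    RewriteSketch literals(String text) {
        segments.add(text);
        return this;
    }

    // Appends a placeholder for a logical table name.
    RewriteSketch table(String logicTable) {
        segments.add(new TableSegment(logicTable));
        return this;
    }

    // Renders the SQL for one route; unmapped logical tables keep their own name.
    String toSQL(Map<String, String> logicToActual) {
        StringBuilder result = new StringBuilder();
        for (Object each : segments) {
            if (each instanceof TableSegment) {
                String logicTable = ((TableSegment) each).logicTable;
                String actualTable = logicToActual.get(logicTable);
                result.append(actualTable != null ? actualTable : logicTable);
            } else {
                result.append(each);
            }
        }
        return result.toString();
    }
}
```

The same RewriteSketch instance can render select * from customer_track_0 where id=? and select * from customer_track_1 where id=? by passing a different map each time, so the SQL is split only once per statement.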

SQLBuilder

First, meet SQLBuilder, the class that stores the SQL segments after rewrite has run.

public final class SQLBuilder {
    // Holds the SQL segments
    private final List<Object> segments;
    // Holds the SQL text that does not change
    private StringBuilder currentSegment;

    public SQLBuilder() {
        segments = new LinkedList<>();
        currentSegment = new StringBuilder();
        segments.add(currentSegment);
    }

    // Append unchanged SQL text
    public void appendLiterals(final String literals) {
        currentSegment.append(literals);
    }

    // Append a table segment
    public void appendTable(final String tableName) {
        segments.add(new TableToken(tableName));
        currentSegment = new StringBuilder();
        segments.add(currentSegment);
    }

    // Append an index segment
    public void appendIndex(final String indexName, final String tableName) {
        segments.add(new IndexToken(indexName, tableName));
        currentSegment = new StringBuilder();
        segments.add(currentSegment);
    }

    // Generate the rewritten SQL statement
    // tableTokens: <logical table name>:<actual table name>
    // e.g. customer_track:customer_track_1
    public String toSQL(final Map<String, String> tableTokens) {
        // Holds the final SQL statement
        StringBuilder result = new StringBuilder();
        for (Object each : segments) {
            // Replace the table name
            if (each instanceof TableToken && tableTokens.containsKey(((TableToken) each).tableName)) {
                result.append(tableTokens.get(((TableToken) each).tableName));
            } else if (each instanceof IndexToken) {
                // For an index, append the actual table name to the index name
                IndexToken indexToken = (IndexToken) each;
                result.append(indexToken.indexName);
                String tableName = tableTokens.get(indexToken.tableName);
                if (!Strings.isNullOrEmpty(tableName)) {
                    result.append("_");
                    result.append(tableName);
                }
            } else {
                result.append(each);
            }
        }
        return result.toString();
    }
    // ... rest omitted
}

SQLRewriteEngine.rewrite()
isRewriteLimit is determined by whether more than one database/table is queried.

public SQLBuilder rewrite(final boolean isRewriteLimit) {
    SQLBuilder result = new SQLBuilder();
    if (sqlTokens.isEmpty()) {
        result.appendLiterals(originalSQL);
        return result;
    }
    int count = 0;
    sortByBeginPosition();
    for (SQLToken each : sqlTokens) {
        if (0 == count) {
            // Keep the leading part of the SQL, which needs no rewriting
            result.appendLiterals(originalSQL.substring(0, each.getBeginPosition()));
        }
        if (each instanceof TableToken) {
            appendTableToken(result, (TableToken) each, count, sqlTokens);
        } else if (each instanceof IndexToken) {
            appendIndexToken(result, (IndexToken) each, count, sqlTokens);
        } else if (each instanceof ItemsToken) {
            appendItemsToken(result, (ItemsToken) each, count, sqlTokens);
        } else if (each instanceof RowCountToken) {
            appendLimitRowCount(result, (RowCountToken) each, count, sqlTokens, isRewriteLimit);
        } else if (each instanceof OffsetToken) {
            appendLimitOffsetToken(result, (OffsetToken) each, count, sqlTokens, isRewriteLimit);
        } else if (each instanceof OrderByToken) {
            appendOrderByToken(result, count, sqlTokens);
        }
        count++;
    }
    return result;
}

SQLRewriteEngine.generateSQL()

public String generateSQL(final TableUnit tableUnit, final SQLBuilder sqlBuilder) {
    return sqlBuilder.toSQL(getTableTokens(tableUnit));
}

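Table name rewriting (TableToken)

SQL table names are replaced when the SQL statement is generated in SQLRewriteEngine.generateSQL(). The rewrite rule: replace each logic table with the actual physical table.
See above:

SQLRewriteEngine.generateSQL() -> SQLBuilder.toSQL()

Index rewriting (IndexToken)

Index rewriting turns a logic-table index into the actual physical-table index, index_tableName, so the logic is the same as for table name rewriting.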

Pagination rewriting

@Override
public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
    //...
    if (sqlStatement instanceof SelectStatement && null != ((SelectStatement) sqlStatement).getLimit()) {
        processLimit(parameters, (SelectStatement) sqlStatement, isSingleRouting);
    }
    //...
}

If the LIMIT clause contains placeholders, pagination is pre-processed in ParsingSQLRouter.route() -> ParsingSQLRouter.processLimit().

Pre-processing LIMIT with placeholders

processLimit() is as follows:

private void processLimit(final List<Object> parameters, final SelectStatement selectStatement, final boolean isSingleRouting) {
    // Only a single database/table is queried
    if (isSingleRouting) {
        selectStatement.setLimit(null);
        return;
    }
    // Fetch everything when there is GROUP BY or an aggregate function
    // and the group-by items differ from the order-by items
    boolean isNeedFetchAll = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems();
    // Rewrite
    selectStatement.getLimit().processParameters(parameters, isNeedFetchAll);
}

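The rewrite lives in the Limit class; parameters holds the arguments bound to the placeholders. The rewrite logic is:

  1. Row count: if isFetchAll is true, it is set to Integer.MAX_VALUE; otherwise, if the database is MySQL, H2 or PostgreSQL, it is set to offset + rowCount; for other databases the passed-in value is kept as is;
  2. Offset: set directly to 0.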
public void processParameters(final List<Object> parameters, final boolean isFetchAll) {
    // Read the values bound to the LIMIT placeholders and set them on offset and rowCount
    fill(parameters);
    // Rewrite
    rewrite(parameters, isFetchAll);
}

private void fill(final List<Object> parameters) {
    int offset = 0;
    if (null != this.offset) {
        // If offset.getIndex() != -1, take the value of the matching placeholder argument
        offset = -1 == this.offset.getIndex() ? getOffsetValue() : NumberUtil.roundHalfUp(parameters.get(this.offset.getIndex()));
        this.offset.setValue(offset);
    }
    int rowCount = 0;
    if (null != this.rowCount) {
        // If rowCount.getIndex() != -1, take the value of the matching placeholder argument
        rowCount = -1 == this.rowCount.getIndex() ? getRowCountValue() : NumberUtil.roundHalfUp(parameters.get(this.rowCount.getIndex()));
        this.rowCount.setValue(rowCount);
    }
    if (offset < 0 || rowCount < 0) {
        throw new SQLParsingException("LIMIT offset and row count can not be a negative value.");
    }
}

private void rewrite(final List<Object> parameters, final boolean isFetchAll) {
    int rewriteOffset = 0;
    int rewriteRowCount;
    // Set the row count to Integer.MAX_VALUE, i.e. no pagination
    if (isFetchAll) {
        rewriteRowCount = Integer.MAX_VALUE;
    } else if (isNeedRewriteRowCount()) {
        // Set the row count to offset + rowCount
        rewriteRowCount = null == rowCount ? -1 : getOffsetValue() + rowCount.getValue();
    } else {
        rewriteRowCount = rowCount.getValue();
    }
    // The offset is simply set to 0
    if (null != offset && offset.getIndex() > -1) {
        parameters.set(offset.getIndex(), rewriteOffset);
    }
    if (null != rowCount && rowCount.getIndex() > -1) {
        parameters.set(rowCount.getIndex(), rewriteRowCount);
    }
}

public boolean isNeedRewriteRowCount() {
    return DatabaseType.MySQL == databaseType || DatabaseType.PostgreSQL == databaseType || DatabaseType.H2 == databaseType;
}
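
To make the rule concrete, here is a small standalone sketch of the same decision. LimitRewriteSketch and its parameter names are hypothetical and exist only for this illustration; the real logic is the Limit.rewrite() method shown above.

```java
// Hypothetical helper mirroring the rewrite rule described above; not part of the library.
final class LimitRewriteSketch {
    // Returns {offset, rowCount} as they would be sent to each shard.
    static int[] rewrite(int offset, int rowCount, boolean fetchAll, boolean dbNeedsRowCountRewrite) {
        int rewrittenRowCount;
        if (fetchAll) {
            // GROUP BY/aggregation with a different ORDER BY: fetch every row
            rewrittenRowCount = Integer.MAX_VALUE;
        } else if (dbNeedsRowCountRewrite) {
            // MySQL/PostgreSQL/H2: each shard must return the first offset + rowCount rows
            rewrittenRowCount = offset + rowCount;
        } else {
            rewrittenRowCount = rowCount;
        }
        // The offset sent to the shards is always 0
        return new int[] {0, rewrittenRowCount};
    }
}
```

For example, with these assumptions a MySQL query paged as offset 10, row count 5 reaches every shard as offset 0, row count 15.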

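Row count rewriting (RowCountToken)

The LIMIT row count is rewritten in SQLRewriteEngine.rewrite(), under these rules:

  1. Queries that hit a single database/table are not rewritten;
  2. With GROUP BY or aggregate functions, and group-by items different from order-by items, the row count is set to Integer.MAX_VALUE, i.e. the database does no pagination;
  3. Otherwise (including GROUP BY with identical group-by and order-by items), the row count becomes offset + row count on MySQL/PostgreSQL/H2 and is left unchanged on other databases.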
private void appendLimitRowCount(final SQLBuilder sqlBuilder, final RowCountToken rowCountToken,
        final int count, final List<SQLToken> sqlTokens, final boolean isRewrite) {
    SelectStatement selectStatement = (SelectStatement) sqlStatement;
    Limit limit = selectStatement.getLimit();
    // isRewrite depends on whether multiple databases/tables have to be queried
    if (!isRewrite) { // do not rewrite LIMIT
        sqlBuilder.appendLiterals(String.valueOf(rowCountToken.getRowCount()));
    } else if ((!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems()) {
        // GROUP BY or aggregate function with group-by items different from the order-by items:
        // set the row count to Integer.MAX_VALUE
        sqlBuilder.appendLiterals(String.valueOf(Integer.MAX_VALUE));
    } else {
        // isNeedRewriteRowCount() is true only for MySQL/PostgreSQL/H2; this branch covers queries
        // without GROUP BY/aggregation, or with identical group-by and order-by items
        sqlBuilder.appendLiterals(String.valueOf(limit.isNeedRewriteRowCount() ? rowCountToken.getRowCount() + limit.getOffsetValue() : rowCountToken.getRowCount()));
    }
    int beginPosition = rowCountToken.getBeginPosition() + String.valueOf(rowCountToken.getRowCount()).length();
    appendRest(sqlBuilder, count, sqlTokens, beginPosition);
}

Offset rewriting (OffsetToken)

The LIMIT offset is rewritten in SQLRewriteEngine.rewrite(), under this rule:

  1. If the query has to hit multiple databases/tables, the offset is simply set to 0.
private void appendLimitOffsetToken(final SQLBuilder sqlBuilder, final OffsetToken offsetToken,
        final int count, final List<SQLToken> sqlTokens, final boolean isRewrite) {
    // isRewrite depends on whether multiple databases/tables have to be queried.
    // When several databases/tables are queried, the offset is simply set to 0.
    sqlBuilder.appendLiterals(isRewrite ? "0" : String.valueOf(offsetToken.getOffset()));
    int beginPosition = offsetToken.getBeginPosition() + String.valueOf(offsetToken.getOffset()).length();
    appendRest(sqlBuilder, count, sqlTokens, beginPosition);
}
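
Why the offset has to become 0 and the row count offset + rowCount is easiest to see on toy data. The sketch below is only an illustration under simplifying assumptions: shards are modelled as pre-sorted in-memory lists, and ShardPagingSketch, shardLimit() and pageAcrossShards() are hypothetical names, not library code.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Illustrative only: shards are modelled as pre-sorted in-memory lists.
final class ShardPagingSketch {
    // Emulates "ORDER BY v LIMIT offset, rowCount" on one sorted shard.
    static List<Integer> shardLimit(List<Integer> sortedShard, int offset, int rowCount) {
        int from = Math.min(offset, sortedShard.size());
        int to = (int) Math.min((long) from + rowCount, sortedShard.size());
        return new ArrayList<>(sortedShard.subList(from, to));
    }

    // Sends the rewritten "LIMIT 0, offset + rowCount" to every shard,
    // then applies the original offset and row count to the merged rows.
    static List<Integer> pageAcrossShards(List<List<Integer>> shards, int offset, int rowCount) {
        List<Integer> merged = new ArrayList<>();
        for (List<Integer> each : shards) {
            merged.addAll(shardLimit(each, 0, offset + rowCount));
        }
        Collections.sort(merged);
        return shardLimit(merged, offset, rowCount);
    }
}
```

With shards [1, 3, 5, 7] and [2, 4, 6, 8], offset 2 and row count 2, the globally correct page is [3, 4]. Sending the original LIMIT to each shard would return [5, 7] and [6, 8], which no longer contain the right rows; the rewritten form fetches the first 4 rows of each shard, so the merge step can still find them.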

Ordering (OrderByToken)

SQLRewriteEngine.rewrite() does not rewrite the ordering itself; appendOrderByToken() generates the ORDER BY clause from the parsed SelectStatement.

private void appendOrderByToken(final SQLBuilder sqlBuilder, final int count, final List<SQLToken> sqlTokens) {
    SelectStatement selectStatement = (SelectStatement) sqlStatement;
    StringBuilder orderByLiterals = new StringBuilder();
    orderByLiterals.append(" ").append(DefaultKeyword.ORDER).append(" ").append(DefaultKeyword.BY).append(" ");
    int i = 0;
    // Append the order-by columns
    for (OrderItem each : selectStatement.getOrderByItems()) {
        String columnLabel = SQLUtil.getOriginalValue(each.getColumnLabel(), databaseType);
        if (0 == i) {
            orderByLiterals.append(columnLabel).append(" ").append(each.getType().name());
        } else {
            orderByLiterals.append(",").append(columnLabel).append(" ").append(each.getType().name());
        }
        i++;
    }
    orderByLiterals.append(" ");
    // Add the ORDER BY segment to the sqlBuilder
    sqlBuilder.appendLiterals(orderByLiterals.toString());
    int beginPosition = ((SelectStatement) sqlStatement).getGroupByLastPosition();
    appendRest(sqlBuilder, count, sqlTokens, beginPosition);
}

Result merging

Result merging runs in ShardingPreparedStatement.executeQuery() -> MergeEngine.merge():

@Override
public ResultSet executeQuery() throws SQLException {
    ResultSet result;
    try {
        // Route
        Collection<PreparedStatementUnit> preparedStatementUnits = route();
        // Execute the SQL
        List<ResultSet> resultSets = new PreparedStatementExecutor(
                getConnection().getShardingContext().getExecutorEngine(), routeResult.getSqlStatement().getType(), preparedStatementUnits, getParameters()).executeQuery();
        // Merge the results
        result = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) routeResult.getSqlStatement()).merge(), this);
    } finally {
        clearBatch();
    }
    currentResultSet = result;
    return result;
}

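No merging (IteratorStreamResultSetMerger)

Queries without ORDER BY, GROUP BY or aggregate functions are wrapped in an IteratorStreamResultSetMerger.

IteratorStreamResultSetMerger iterates over the result sets one after another.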

public final class IteratorStreamResultSetMerger extends AbstractStreamResultSetMerger {

    private final Iterator<ResultSet> resultSets;

    public IteratorStreamResultSetMerger(final List<ResultSet> resultSets) {
        this.resultSets = resultSets.iterator();
        setCurrentResultSet(this.resultSets.next());
    }

    @Override
    public boolean next() throws SQLException {
        if (getCurrentResultSet().next()) {
            return true;
        }
        if (!resultSets.hasNext()) {
            return false;
        }
        // Move on to the next result set
        setCurrentResultSet(resultSets.next());
        boolean hasNext = getCurrentResultSet().next();
        if (hasNext) {
            return true;
        }
        while (!hasNext && resultSets.hasNext()) {
            setCurrentResultSet(resultSets.next());
            hasNext = getCurrentResultSet().next();
        }
        return hasNext;
    }
}
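
The same "drain one source, then move to the next" pattern can be shown without JDBC. This is a minimal sketch in which plain Iterable collections stand in for ResultSet objects; ChainedIteratorSketch is a hypothetical name, not a library class.

```java
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

// Hypothetical sketch: plain iterables stand in for JDBC ResultSets.
final class ChainedIteratorSketch<T> implements Iterator<T> {
    private final Iterator<? extends Iterable<T>> sources;
    private Iterator<T> current;

    ChainedIteratorSketch(List<? extends Iterable<T>> shardRows) {
        this.sources = shardRows.iterator();
        this.current = sources.hasNext() ? sources.next().iterator() : null;
    }

    @Override
    public boolean hasNext() {
        // Skip over exhausted (or empty) shards until one still has a row.
        while (current != null && !current.hasNext()) {
            current = sources.hasNext() ? sources.next().iterator() : null;
        }
        return current != null;
    }

    @Override
    public T next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return current.next();
    }
}
```

Rows come out in shard order with no comparison between shards, which is why this strategy only fits queries that need no sorting, grouping or aggregation.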

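Order-by merging (OrderByStreamResultSetMerger)

OrderByStreamResultSetMerger sequence diagram.png

Sorting logic

  1. Use a priority queue that compares the order-by values of each result set's current row and orders the result sets accordingly;
  2. Read one row from the result set at the head of the queue;
  3. On next(), poll the head result set from the queue, call its next() so the cursor moves to the following row, then put the result set back into the queue to be re-ordered.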
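
The three steps above amount to a k-way merge. The sketch below illustrates it on sorted integer lists instead of result sets; KWayMergeSketch and Cursor are hypothetical names, and the sketch assumes the shard data contains no null values.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical sketch: each shard is an already-sorted list of non-null integers.
final class KWayMergeSketch {
    // Cursor over one shard; compared by the row it currently points at.
    private static final class Cursor implements Comparable<Cursor> {
        private final Iterator<Integer> rows;
        private Integer head;

        Cursor(Iterator<Integer> rows) {
            this.rows = rows;
        }

        // Moves to the next row; returns false when the shard is exhausted.
        boolean advance() {
            head = rows.hasNext() ? rows.next() : null;
            return head != null;
        }

        @Override
        public int compareTo(Cursor other) {
            return head.compareTo(other.head);
        }
    }

    static List<Integer> merge(List<List<Integer>> sortedShards) {
        PriorityQueue<Cursor> queue = new PriorityQueue<>();
        for (List<Integer> each : sortedShards) {
            Cursor cursor = new Cursor(each.iterator());
            if (cursor.advance()) {
                queue.offer(cursor);
            }
        }
        List<Integer> result = new ArrayList<>();
        while (!queue.isEmpty()) {
            Cursor first = queue.poll();   // cursor holding the smallest current row
            result.add(first.head);
            if (first.advance()) {         // move it forward and re-queue it
                queue.offer(first);
            }
        }
        return result;
    }
}
```

Only one row per shard is held in the queue at any time, so memory use depends on the number of shards rather than on the number of rows.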

Fields

public class OrderByStreamResultSetMerger extends AbstractStreamResultSetMerger {
    // Order-by item metadata
    private final List<OrderItem> orderByItems;

    // Queue of result sets
    private final Queue<OrderByValue> orderByValuesQueue;

    private boolean isFirstNext;
}

Constructor

The constructor stores the result sets in a priority queue, wrapped as OrderByValue objects.

public OrderByStreamResultSetMerger(final List<ResultSet> resultSets, final List<OrderItem> orderByItems) throws SQLException {
    this.orderByItems = orderByItems;
    this.orderByValuesQueue = new PriorityQueue<>(resultSets.size());
    // Put the result sets into the priority queue
    orderResultSetsToQueue(resultSets);
    isFirstNext = true;
}

private void orderResultSetsToQueue(final List<ResultSet> resultSets) throws SQLException {
    // Iterate and add each non-empty result set to the queue
    for (ResultSet each : resultSets) {
        OrderByValue orderByValue = new OrderByValue(each, orderByItems);
        if (orderByValue.next()) {
            orderByValuesQueue.offer(orderByValue);
        }
    }
    // Set the current result set to the head of the queue
    setCurrentResultSet(orderByValuesQueue.isEmpty() ? resultSets.get(0) : orderByValuesQueue.peek().getResultSet());
}

The next() method

@Override
public boolean next() throws SQLException {
    if (orderByValuesQueue.isEmpty()) {
        return false;
    }
    if (isFirstNext) {
        isFirstNext = false;
        return true;
    }
    // Poll the head of the queue
    OrderByValue firstOrderByValue = orderByValuesQueue.poll();
    // Advance its cursor
    if (firstOrderByValue.next()) {
        // Put it back into the queue so it is re-ordered
        orderByValuesQueue.offer(firstOrderByValue);
    }
    if (orderByValuesQueue.isEmpty()) {
        return false;
    }
    // Set the current result set
    setCurrentResultSet(orderByValuesQueue.peek().getResultSet());
    return true;
}

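The OrderByValue class

  1. The priority queue in OrderByStreamResultSetMerger decides priority through the Comparable interface implemented by OrderByValue;
  2. OrderByValue compares the order-by values one by one; the first order-by column whose values differ determines the ordering of the two OrderByValue objects (the rows the two result-set cursors point at are compared column by column);
  3. OrderByValue.next() is called from OrderByStreamResultSetMerger.next();
  4. OrderByValue.next() advances the result set (resultSet.next()) to the next row.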
    public final class OrderByValue implements Comparable<OrderByValue> {

        @Getter
        private final ResultSet resultSet;
        // Order-by item metadata
        private final List<OrderItem> orderByItems;
        // Values of the order-by columns in the current row
        private List<Comparable<?>> orderValues;

        // next() calls getOrderValues() to read the order-by values of the current row
        public boolean next() throws SQLException {
            // Advance the underlying result set
            boolean result = resultSet.next();
            orderValues = result ? getOrderValues() : Collections.<Comparable<?>>emptyList();
            return result;
        }

        // Walk the order-by items and collect the matching column values
        private List<Comparable<?>> getOrderValues() throws SQLException {
            List<Comparable<?>> result = new ArrayList<>(orderByItems.size());
            for (OrderItem each : orderByItems) {
                Object value = resultSet.getObject(each.getIndex());
                Preconditions.checkState(null == value || value instanceof Comparable, "Order by value must implements Comparable");
                result.add((Comparable<?>) value);
            }
            return result;
        }

        // Compare the order-by values one by one and return at the first difference
        @Override
        public int compareTo(final OrderByValue o) {
            for (int i = 0; i < orderByItems.size(); i++) {
                OrderItem thisOrderBy = orderByItems.get(i);
                int result = ResultSetUtil.compareTo(orderValues.get(i), o.orderValues.get(i), thisOrderBy.getType(), thisOrderBy.getNullOrderType());
                if (0 != result) {
                    return result;
                }
            }
            return 0;
        }
    }

Merging aggregate function results

GroupByStreamResultSetMerger (group-by items identical to order-by items)

GroupByStreamResultSetMerger.png

Fields

public final class GroupByStreamResultSetMerger extends OrderByStreamResultSetMerger {
    // key: column label, value: column index
    private final Map<String, Integer> labelAndIndexMap;
    // The parsed SELECT statement
    private final SelectStatement selectStatement;
    // Current row (after merging)
    private final List<Object> currentRow;
    // Group-by values used to decide what to merge; after currentRow is updated they point at the next row
    private List<?> currentGroupByValues;
}

Constructor

public GroupByStreamResultSetMerger(
        final Map<String, Integer> labelAndIndexMap, final List<ResultSet> resultSets, final SelectStatement selectStatement) throws SQLException {
    // Sorting is delegated to OrderByStreamResultSetMerger
    super(resultSets, selectStatement.getOrderByItems());
    this.labelAndIndexMap = labelAndIndexMap;
    this.selectStatement = selectStatement;
    currentRow = new ArrayList<>(labelAndIndexMap.size());
    // Read the group-by values of the first row
    currentGroupByValues = getOrderByValuesQueue().isEmpty() ? Collections.emptyList() : new GroupByValue(getCurrentResultSet(), selectStatement.getGroupByItems()).getGroupValues();
}

Reading data

@Override
public Object getValue(final int columnIndex, final Class<?> type) throws SQLException {
    return currentRow.get(columnIndex - 1);
}

@Override
public Object getValue(final String columnLabel, final Class<?> type) throws SQLException {
    Preconditions.checkState(labelAndIndexMap.containsKey(columnLabel), String.format("Can't find columnLabel: %s", columnLabel));
    return currentRow.get(labelAndIndexMap.get(columnLabel) - 1);
}

The next() method

@Override
public boolean next() throws SQLException {
    currentRow.clear();
    if (getOrderByValuesQueue().isEmpty()) {
        return false;
    }
    if (isFirstNext()) {
        super.next();
    }
    // Merge
    if (aggregateCurrentGroupByRowAndNext()) {
        // Point currentGroupByValues at the group-by values of the next row
        currentGroupByValues = new GroupByValue(getCurrentResultSet(), selectStatement.getGroupByItems()).getGroupValues();
    }
    return true;
}

// Merge
private boolean aggregateCurrentGroupByRowAndNext() throws SQLException {
    boolean result = false;
    // One aggregation unit per aggregate column
    Map<AggregationSelectItem, AggregationUnit> aggregationUnitMap = Maps.toMap(selectStatement.getAggregationSelectItems(), new Function<AggregationSelectItem, AggregationUnit>() {

        @Override
        public AggregationUnit apply(final AggregationSelectItem input) {
            return AggregationUnitFactory.create(input.getType());
        }
    });
    // Compare currentGroupByValues (the current group's values) with each row's group-by values until they differ.
    // The first comparison compares the row with itself.
    // super.next() advances to the next row while keeping the sort order.
    while (currentGroupByValues.equals(new GroupByValue(getCurrentResultSet(), selectStatement.getGroupByItems()).getGroupValues())) {
        // Merge the aggregate results, e.g. COUNT values are summed, AVG is averaged, and so on
        aggregate(aggregationUnitMap);
        // Cache the current row; at this point the aggregate columns do not hold the merged values yet
        cacheCurrentRow();
        result = super.next();
        if (!result) {
            break;
        }
    }
    // Now write the merged aggregate values into the matching columns
    setAggregationValueToCurrentRow(aggregationUnitMap);
    return result;
}

private void aggregate(final Map<AggregationSelectItem, AggregationUnit> aggregationUnitMap) throws SQLException {
    for (Entry<AggregationSelectItem, AggregationUnit> entry : aggregationUnitMap.entrySet()) {
        List<Comparable<?>> values = new ArrayList<>(2);
        if (entry.getKey().getDerivedAggregationSelectItems().isEmpty()) {
            values.add(getAggregationValue(entry.getKey()));
        } else {
            for (AggregationSelectItem each : entry.getKey().getDerivedAggregationSelectItems()) {
                values.add(getAggregationValue(each));
            }
        }
        // Each aggregate function has its own merge
        entry.getValue().merge(values);
    }
}
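
Stripped of JDBC details, the streaming merge relies on one property: rows arrive already sorted by the group-by key, so rows of the same group are adjacent. A minimal sketch under that assumption, with each row reduced to a {groupKey, partialCount} pair and a COUNT-style sum as the only aggregate (StreamGroupBySketch is a hypothetical name):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: rows are {groupKey, partialCount} pairs already in global key order,
// as a sorted merge of the shard results would deliver them.
final class StreamGroupBySketch {
    static List<long[]> mergeSortedGroups(List<long[]> sortedRows) {
        List<long[]> result = new ArrayList<>();
        int i = 0;
        while (i < sortedRows.size()) {
            long currentKey = sortedRows.get(i)[0];
            long total = 0;
            // Consume consecutive rows while the group key stays the same.
            while (i < sortedRows.size() && sortedRows.get(i)[0] == currentKey) {
                total += sortedRows.get(i)[1];
                i++;
            }
            result.add(new long[] {currentKey, total});
        }
        return result;
    }
}
```

Each output row is complete as soon as the key changes, so nothing beyond the current group has to be kept in memory.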

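GroupByMemoryResultSetMerger (group-by items differ from order-by items)

Merge logic
All data is merged in the constructor, which walks every row:

  1. Merge the aggregate columns of rows that share the same group-by values;
  2. Sort the merged data by the order-by columns into a List, where each element represents one row;
  3. Obtain the List's iterator, memoryResultSetRows, and serve next() from that iterator.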
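
A compact sketch of these three steps, again with rows reduced to {groupKey, partialCount} pairs. MemoryGroupBySketch is a hypothetical name, and ordering by the merged total (descending) is chosen here only as an example of an order-by that differs from the group-by key:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: rows are {groupKey, partialCount}; the result is ordered by total count
// descending, i.e. by something other than the group key, so every row must be read first.
final class MemoryGroupBySketch {
    static List<long[]> mergeAndSort(List<List<long[]>> shardRows) {
        // Step 1: merge the aggregate of rows that share a group key.
        Map<Long, Long> totals = new HashMap<>();
        for (List<long[]> shard : shardRows) {
            for (long[] row : shard) {
                totals.merge(row[0], row[1], Long::sum);
            }
        }
        // Step 2: turn the merged groups into rows and sort them by the order-by column.
        List<long[]> result = new ArrayList<>();
        for (Map.Entry<Long, Long> each : totals.entrySet()) {
            result.add(new long[] {each.getKey(), each.getValue()});
        }
        // Ties on the total are broken by group key so the output is deterministic.
        result.sort((a, b) -> a[1] != b[1] ? Long.compare(b[1], a[1]) : Long.compare(a[0], b[0]));
        // Step 3: callers iterate over the sorted list.
        return result;
    }
}
```

Unlike the streaming variant, the whole merged result lives in memory before the first row can be returned, which is the cost of ordering by a column other than the group-by key.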

Fields

public final class GroupByMemoryResultSetMerger extends AbstractMemoryResultSetMerger {

    private final SelectStatement selectStatement;
    // Each MemoryResultSetRow stores one merged row
    private final Iterator<MemoryResultSetRow> memoryResultSetRows;
}

Constructor

public GroupByMemoryResultSetMerger(
        final Map<String, Integer> labelAndIndexMap, final List<ResultSet> resultSets, final SelectStatement selectStatement) throws SQLException {
    super(labelAndIndexMap);
    this.selectStatement = selectStatement;
    // Initialize
    memoryResultSetRows = init(resultSets);
}

private Iterator<MemoryResultSetRow> init(final List<ResultSet> resultSets) throws SQLException {
    // value: the merged row
    Map<GroupByValue, MemoryResultSetRow> dataMap = new HashMap<>(1024);
    // key: group-by values, value: <aggregate column, aggregation unit>
    Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap = new HashMap<>(1024);
    // Walk all rows
    for (ResultSet each : resultSets) {
        while (each.next()) {
            // Group-by values of this row
            GroupByValue groupByValue = new GroupByValue(each, selectStatement.getGroupByItems());
            // If this GroupByValue is not yet a key in dataMap and aggregationMap, initialize and store it
            initForFirstGroupByValue(each, groupByValue, dataMap, aggregationMap);
            // Merge the aggregate columns
            aggregate(each, groupByValue, aggregationMap);
        }
    }
    // Copy the merged values from aggregationMap into dataMap
    setAggregationValueToMemoryRow(dataMap, aggregationMap);
    // Sort the merged rows by the order-by columns
    List<MemoryResultSetRow> result = getMemoryResultSetRows(dataMap);
    if (!result.isEmpty()) {
        setCurrentResultSetRow(result.get(0));
    }
    return result.iterator();
}

private void initForFirstGroupByValue(final ResultSet resultSet, final GroupByValue groupByValue, final Map<GroupByValue, MemoryResultSetRow> dataMap,
        final Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap) throws SQLException {
    if (!dataMap.containsKey(groupByValue)) {
        dataMap.put(groupByValue, new MemoryResultSetRow(resultSet));
    }
    if (!aggregationMap.containsKey(groupByValue)) {
        Map<AggregationSelectItem, AggregationUnit> map = Maps.toMap(selectStatement.getAggregationSelectItems(), new Function<AggregationSelectItem, AggregationUnit>() {

            @Override
            public AggregationUnit apply(final AggregationSelectItem input) {
                return AggregationUnitFactory.create(input.getType());
            }
        });
        aggregationMap.put(groupByValue, map);
    }
}

private void aggregate(final ResultSet resultSet, final GroupByValue groupByValue, final Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap) throws SQLException {
    for (AggregationSelectItem each : selectStatement.getAggregationSelectItems()) {
        List<Comparable<?>> values = new ArrayList<>(2);
        if (each.getDerivedAggregationSelectItems().isEmpty()) {
            values.add(getAggregationValue(resultSet, each));
        } else {
            for (AggregationSelectItem derived : each.getDerivedAggregationSelectItems()) {
                values.add(getAggregationValue(resultSet, derived));
            }
        }
        // Merge the aggregate column (each) across rows with equal groupByValue
        aggregationMap.get(groupByValue).get(each).merge(values);
    }
}

The next() method
Calls next() on the iterator.

@Override
public boolean next() throws SQLException {
    if (memoryResultSetRows.hasNext()) {
        setCurrentResultSetRow(memoryResultSetRows.next());
        return true;
    }
    return false;
}

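MySQL pagination merging (LimitDecoratorResultSetMerger)

LimitDecoratorResultSetMerger is a wrapper class that can decorate any merger.

It handles pagination by calling next() on the wrapped ResultSetMerger offset times in its constructor to skip those rows; when data is read, its own next() method caps the number of rows returned.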

public final class LimitDecoratorResultSetMerger extends AbstractDecoratorResultSetMerger {

    private final Limit limit;

    private final boolean skipAll;

    private int rowNumber;

    public LimitDecoratorResultSetMerger(final ResultSetMerger resultSetMerger, final Limit limit) throws SQLException {
        super(resultSetMerger);
        this.limit = limit;
        // Skip rows by calling next() offset times
        skipAll = skipOffset();
    }

    private boolean skipOffset() throws SQLException {
        for (int i = 0; i < limit.getOffsetValue(); i++) {
            if (!getResultSetMerger().next()) {
                return true;
            }
        }
        rowNumber = 0;
        return false;
    }

    @Override
    public boolean next() throws SQLException {
        if (skipAll) {
            return false;
        }
        if (limit.getRowCountValue() < 0) {
            return getResultSetMerger().next();
        }
        // Cap the total number of rows
        return ++rowNumber <= limit.getRowCountValue() && getResultSetMerger().next();
    }
}
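
The decorator idea carries over to any row source. Below is a minimal sketch over java.util.Iterator rather than ResultSetMerger; LimitIteratorSketch is a hypothetical name, and, as in the class above, a negative row count is taken to mean "no row limit".

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

// Hypothetical sketch: wraps any merged row iterator and applies offset/rowCount on top of it.
final class LimitIteratorSketch<T> implements Iterator<T> {
    private final Iterator<T> delegate;
    private final int rowCount;   // a negative value means "no row limit"
    private int emitted;

    LimitIteratorSketch(Iterator<T> delegate, int offset, int rowCount) {
        this.delegate = delegate;
        this.rowCount = rowCount;
        // Skip the first `offset` rows up front, in the constructor.
        for (int i = 0; i < offset && delegate.hasNext(); i++) {
            delegate.next();
        }
    }

    @Override
    public boolean hasNext() {
        return (rowCount < 0 || emitted < rowCount) && delegate.hasNext();
    }

    @Override
    public T next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        emitted++;
        return delegate.next();
    }
}
```

Because the limit logic only talks to the wrapped iterator, it can sit on top of a plain, sorted or grouped merge without knowing which one it is.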

References

Core concepts
sharding-jdbc database and table sharding rules (2): multi-table queries