題目大意
如果一個(gè) N × N 的矩陣滿足:
- 矩陣每行均為 [1, N] 的正整數(shù)的一個(gè)排列
- 矩陣內(nèi)所有元素與其上方的元素不同
那么這個(gè)矩陣便是美麗的。
現(xiàn)給定一個(gè) N × N 的美麗矩陣,求有多少個(gè) N × N 的美麗矩陣比它小。(矩陣從上到下按行比較)
題目保證 N 不超過 2000
分析
這個(gè)題的切入點(diǎn)在于美麗矩陣的定義。如果我們把當(dāng)前行看作待排序的一個(gè)序列,上面一行當(dāng)成排序基準(zhǔn),則這個(gè)問題可以轉(zhuǎn)化成錯(cuò)位排序問題。但是不同的是,在排了一部分?jǐn)?shù)字以后,剩下的部分的排序標(biāo)準(zhǔn)就不那么嚴(yán)苛了(即存在一些可行數(shù)沒有禁止位置)。
若 i 表示序列的長度, j 表示存在禁止位置的元素個(gè)數(shù),則由容斥原理易得:
這個(gè)表達(dá)式非常優(yōu)美,但是我們需要求 O(N2) 個(gè) dp 值,如果直接計(jì)算的話需要 O(N3) 。不能承受??紤]到組合遞推關(guān)系:
我們猜想 dp[i][j] 可以由 dp[i][j - 1] 和 dp[i - 1][j - 1] 推出。果然,我們有:
現(xiàn)在我們來解決這個(gè)問題。根據(jù)題目的定義,兩個(gè)矩陣的比較與兩個(gè)字符串的比較方式類似,如果 A 矩陣小于 B 矩陣,那么 A 矩陣的任意“前綴”小于等于 B 矩陣的對(duì)應(yīng)“前綴”。如果兩個(gè)矩陣的第一個(gè)不相同元素的位置為 (i, j) ,那么對(duì)于給定的 B 矩陣,這樣的 A 矩陣共有
其中 way0 和 way1 分別表示有多少種選法使得 A[i][j] < B[i][j] 且是否選取 A[i - 1] 中在 j 位置以前出現(xiàn)過的元素; cnt 表示 A[i - 1] 的前 j 個(gè)元素與 A[i] 的前 (j - 1) 個(gè)元素的相同個(gè)數(shù)。
如果我們用樹狀數(shù)組或名次樹來滑動(dòng)地維護(hù) way0 和 way1 ,則均攤時(shí)間復(fù)雜度可降為每個(gè)位置 O(logN) 。剪枝以后可以接受。
代碼
總復(fù)雜度為 O(n2log(n))
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template <typename T>
using ordered_set = tree<T, null_type, less<T>, rb_tree_tag,
tree_order_statistics_node_update>;
typedef long long ll;
typedef pair<int, int> pii;
#define FOR(i, a, b) for (int (i) = (a); (i) <= (b); (i)++)
#define ROF(i, a, b) for (int (i) = (a); (i) >= (b); (i)--)
#define REP(i, n) FOR(i, 0, (n)-1)
#define sqr(x) ((x) * (x))
#define all(x) (x).begin(), (x).end()
#define reset(x, y) memset(x, y, sizeof(x))
#define uni(x) (x).erase(unique(all(x)), (x).end());
#define BUG(x) cerr << #x << " = " << (x) << endl
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define _1 first
#define _2 second
const int maxn = 2123;
const ll MOD = 998244353;
ll fac[maxn], dp[maxn][maxn], ans, D[maxn];
int n, a[maxn][maxn];
pii way[maxn][maxn];
int main() {
scanf("%d", &n);
fac[0] = 1;
FOR(i, 1, n) fac[i] = fac[i - 1] * i % MOD;
dp[0][0] = 1;
FOR(i, 1, n) {
dp[i][0] = fac[i];
FOR(j, 1, i) {
dp[i][j] = (dp[i][j - 1] - dp[i - 1][j - 1]) % MOD;
if (dp[i][j] < 0) dp[i][j] += MOD;
}
}
D[0] = 1;
FOR(i, 1, n) D[i] = D[i - 1] * dp[n][n] % MOD;
FOR(i, 1, n) FOR(j, 1, n) scanf("%d", &a[i][j]);
FOR(i, 1, n) {
ordered_set<int> s[2];
FOR(j, 1, n) s[1].insert(j);
FOR(j, 1, n) {
way[i][j]._1 = s[0].order_of_key(a[i][j]);
if (a[i - 1][j] < a[i][j] && s[0].find(a[i - 1][j]) != s[0].end())
way[i][j]._1--;
way[i][j]._2 = s[1].order_of_key(a[i][j]);
if (a[i - 1][j] < a[i][j] && s[1].find(a[i - 1][j]) != s[1].end())
way[i][j]._2--;
s[0].erase(a[i][j]), s[1].erase(a[i][j]);
if (s[1].find(a[i - 1][j]) != s[1].end()) {
s[1].erase(a[i - 1][j]);
s[0].insert(a[i - 1][j]);
}
}
}
FOR(i, 1, n)
ans = (ans + way[1][i]._2 * fac[n - i] % MOD * D[n - 1]) % MOD;
FOR(i, 2, n) {
unordered_map<int, int> m;
FOR(j, 1, n) {
m[a[i - 1][j]]++;
int cnt = 2 * j - 1 - m.size();
ans = (ans + way[i][j]._1 * dp[n - j][n - 2 * j + cnt + 1]
% MOD * D[n - i]) % MOD;
ans = (ans + way[i][j]._2 * dp[n - j][n - 2 * j + cnt]
% MOD * D[n - i]) % MOD;
m[a[i][j]]++;
}
}
printf("%lld", ans);
}