#include <bits/stdc++.h>
// Shorthand aliases used throughout: ii = pair<int,int>, ll = long long.
// _o(...) writes each argument to stderr preceded by a space and then ends the
// line with endl (which also flushes); it is the back end of debug(...) below.
using namespace std;using ii=pair<int,int>;using ll=long long;void _o(){cerr<<endl;}template<class H,class...T>void _o(H h,T...t){cerr<<' '<<h;_o(t...);}
// debug(x, y, ...) prints "[x, y, ...]:" (the argument text) followed by their values to stderr.
#define debug(...)cerr<<'['<<#__VA_ARGS__<<"]:",_o(__VA_ARGS__)
// Inclusive loop: i runs from a up to and including b; i takes the type of a.
#define rep(i,a,b) for (auto i=(a); i<=(b); ++i)
#define all(x) (x).begin(),(x).end()
// NOTE: size is a *function-like* macro, so from here on any member call spelled
// `obj.size()` is also expanded (to `obj.int(().size())`) and fails to compile.
// Write size(obj) instead of obj.size() after this line.
#define size(x) int((x).size())
// NOTE: object-like macros; every later identifier `fi` / `se` becomes first / second.
#define fi first
#define se second
다음은 다중 시작점 다익스트라(multi-source Dijkstra)를 포함하는 무방향 가중치 그래프 템플릿이다:
/// Undirected weighted graph on vertices 0..V-1, stored as adjacency lists of
/// (neighbour, weight) pairs, with a multi-source Dijkstra.
template <const int V> struct graph {
    vector<ii> adj[V];

    // Empties the adjacency lists of vertices 0..v (inclusive).
    void clear(int v) {
        for (int x = 0; x <= v; ++x) adj[x].clear();
    }

    // Inserts the undirected edge {s, t} with weight w (recorded at both endpoints).
    void add_edge(int s, int t, int w) {
        adj[s].push_back({t, w});
        adj[t].push_back({s, w});
    }

    // Fills dist[0..V-1] with the shortest distance from each vertex to the
    // nearest vertex listed in `source`; unreachable vertices keep 1e9.
    void dijkstra(vector<int> source, int dist[V]) {
        using item = pair<int, int>;                               // (distance, vertex)
        priority_queue<item, vector<item>, greater<item>> pq;      // smallest distance on top
        fill(dist, dist + V, 1e9);
        for (int s : source) {
            dist[s] = 0;
            pq.push({0, s});
        }
        while (!pq.empty()) {
            const int cur = pq.top().first;
            const int x = pq.top().second;
            pq.pop();
            // a stale heap entry: x was already settled with a smaller distance
            if (cur != dist[x]) continue;
            for (const auto &[y, w] : adj[x]) {
                const int cand = cur + w;
                if (cand < dist[y]) {
                    dist[y] = cand;
                    pq.push({cand, y});
                }
            }
        }
    }
};
다음은 특별한 변형 없는 서로소 집합(disjoint set) 템플릿이며, union by size(작은 집합을 큰 집합 밑에 붙이기)와 경로 압축(path compression)을 사용했다:
/// Disjoint-set union over elements 0..N-1 using union by size and path compression.
template <const int N> struct disjoint_set {
    int par[N], sz[N];   // parent pointer; sz[] is only meaningful at a root

    disjoint_set() { clear(N-1); }

    // Turns every element 0..n (inclusive) back into its own singleton set.
    void clear(int n) {
        for (int i = 0; i <= n; ++i) {
            par[i] = i;
            sz[i] = 1;
        }
    }

    // Returns the root of x's set and re-points every node on the way directly at it.
    int find(int x) {
        int root = x;
        while (par[root] != root) root = par[root];
        while (x != root) {
            const int up = par[x];
            par[x] = root;
            x = up;
        }
        return root;
    }

    // Unites the sets of a and b; returns false when they were already one set.
    bool merge(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return false;
        if (sz[ra] > sz[rb]) swap(ra, rb);   // hang the smaller tree below the larger root
        par[ra] = rb;
        sz[rb] += sz[ra];
        return true;
    }

    // Number of elements in x's set.
    int set_size(int x) {
        const int r = find(x);
        return sz[r];
    }

    // Whether a and b are currently in the same set.
    bool same(int a, int b) {
        const int ra = find(a);
        return ra == find(b);
    }
};
마지막으로 트리를 관리하는 템플릿이다. 경로 위 값들의 최솟값을 구하는 경로 쿼리인데, min은 역원이 존재하지 않는 연산이므로 (루트까지의 누적값을 빼서 구하는) DFS ordering + RMQ 방식은 쓸 수 없다. 또한 경로 위 정점들의 값을 모두 반영해야 하므로 ETT + RMQ를 쓰는 것은 번거롭다. 따라서 각 정점의 2^j번째 조상과 그 구간의 최솟값을 sparse table(binary lifting)로 저장해 두면 된다:
/// Tree on vertices 1..v (vertex 0 is a sentinel "parent of the root") that
/// answers path-minimum queries over vertex values with binary lifting.
///   kth[x][j] = the 2^j-th ancestor of x (0 once the walk passes the root)
///   val[x][j] = min of value over the 2^j vertices x, par[x], ... starting at x;
///               row 0 holds INF, so jumping past the root never lowers a minimum
template <const int V> struct static_tree {
    static constexpr int LOG = 20;          // jump lengths 2^0 .. 2^LOG
    static constexpr int INF = 1000000000;  // identity for min (same value as 1e9)
    static_assert(V <= (1 << (LOG + 1)), "static_tree: depth differences must fit in LOG+1 bits");

    vector<int> adj[V];
    int par[V], lev[V], kth[V][LOG + 1], val[V][LOG + 1];

    // Sets up the sentinel vertex 0 explicitly instead of relying on
    // zero-initialised static storage, so the tree also works as a local object.
    static_tree() {
        for (auto &row : val) fill(begin(row), end(row), INF);
        fill(begin(kth[0]), end(kth[0]), 0);
        par[0] = lev[0] = 0;
    }

    void add_edge(int s, int t){ adj[s].push_back(t), adj[t].push_back(s); }

    // Walks the tree from s (entered from e) and fills par[] and lev[].
    // An explicit stack is used because a recursive DFS nests once per vertex,
    // and a path-shaped tree with ~V vertices can overflow the call stack.
    void dfs(int s = 1, int e = 0) {
        vector<pair<int,int>> stk;   // (vertex, parent)
        stk.push_back({s, e});
        while (!stk.empty()) {
            auto [x, p] = stk.back();
            stk.pop_back();
            lev[x] = lev[p] + 1, par[x] = p;
            for (int u : adj[x]) if (u != p) stk.push_back({u, x});
        }
    }

    // Roots the tree at vertex 1 and builds the lifting tables from value[1..v].
    void build(int v, int value[]) {
        dfs();
        for (int i = 1; i <= v; ++i) val[i][0] = value[i], kth[i][0] = par[i];
        for (int j = 1; j <= LOG; ++j)
            for (int i = 1; i <= v; ++i) {
                const int mid = kth[i][j-1];
                kth[i][j] = kth[mid][j-1];
                val[i][j] = min(val[i][j-1], val[mid][j-1]);
            }
    }

    // Minimum of value over every vertex on the tree path a..b (both ends inclusive).
    int query(int a, int b) {
        int r = INF;
        if (lev[a] > lev[b]) swap(a, b);   // make b the deeper endpoint
        // lift b to a's depth, folding in each vertex it leaves behind
        const int diff = lev[b] - lev[a];
        for (int i = 0; i <= LOG; ++i)
            if (diff >> i & 1) r = min(r, val[b][i]), b = kth[b][i];
        if (a == b) return min(r, val[a][0]);
        // lift both while their ancestors still differ (i.e. stay below the LCA)
        for (int i = LOG; i >= 0; --i)
            if (kth[a][i] != kth[b][i]) {
                r = min({r, val[a][i], val[b][i]});
                a = kth[a][i], b = kth[b][i];
            }
        // a and b are now distinct children of the LCA: val[a][1] covers a and the LCA
        return min({r, val[a][1], val[b][0]});
    }
};
메인 로직이다:
// Program 1. Reads n m k q, then m weighted edges "s t d", then the k source
// cities, then q queries "s t"; prints st.query(s, t) for each query.
const int N = 1e5+3;
int n, m, k, q, value[N];
graph<N> g;
disjoint_set<N> dsu;
static_tree<N> st;
int main() {
    cin.tie(0)->sync_with_stdio(0);
    cin >> n >> m >> k >> q;
    for (int e = 0; e < m; ++e) {
        int s, t, d;
        cin >> s >> t >> d;
        g.add_edge(s, t, d);
    }
    vector<int> cities(k);
    for (int i = 0; i < k; ++i) cin >> cities[i];
    // value[x] := shortest distance from x to the closest listed city
    g.dijkstra(cities, value);

    // Kruskal-style construction: insert vertices by decreasing value and link
    // each to every already-inserted neighbour in a different component. Every
    // edge added at u has weight min(value) == value[u], so edges arrive in
    // non-increasing order of min(value[u], value[v]).
    vector<int> order(n);
    iota(order.begin(), order.end(), 1);
    sort(order.begin(), order.end(), [](int x, int y){ return value[x] > value[y]; });
    vector<char> inserted(N, 0);
    for (int u : order) {
        inserted[u] = 1;
        for (const auto &edge : g.adj[u]) {
            const int v = edge.first;
            if (inserted[v] && dsu.merge(u, v)) st.add_edge(u, v);
        }
    }
    st.build(n, value);

    for (int i = 0; i < q; ++i) {
        int s, t;
        cin >> s >> t;
        cout << st.query(s, t) << '\n';
    }
}
#include <bits/stdc++.h>
using namespace std;
// Inclusive loop: i runs from a up to and including b; i takes the type of a.
#define rep(i,a,b) for (auto i = (a); i <= (b); ++i)
// Fixed-capacity directed multigraph in "forward star" form.
// Edge i goes u[i] -> v[i] with weight w[i]. next[i] links it to the edge that
// was previously the first edge out of the same tail. start[x] is the most
// recently added edge out of x. Edge index 0 terminates every list, so at most
// E-1 edges can be stored.
template <typename T, const int V, const int E>
struct graph {
    // Value-initialised so that a non-static instance also starts with empty
    // lists (start[x] == 0) instead of reading indeterminate memory.
    int u[E]{}, v[E]{}; T w[E]{};
    int next[E]{}, start[V]{}, p = 1;   // p = next free edge slot
    // Appends the directed edge x -> y with weight z (default 1).
    void add_edge(int x, int y, T z = 1) {
        assert(p < E && "graph: edge capacity E exceeded (slot 0 is reserved)");
        u[p] = x, v[p] = y, w[p] = z;
        next[p] = start[x], start[x] = p++;
    }
    // for_adj(G, S, U, W) { ... } visits the out-edges of S, newest first,
    // declaring int U (head vertex) and int W (weight) for the loop body.
    // The macro parameters are upper-case so they can never capture the member
    // names used in the expansion. With a parameter named `w`, `g.w[o]` would be
    // rewritten to `g.<4th argument>[o]`.
#define for_adj(G,S,U,W) for (int o = (G).start[S], U, W; \
tie(U,W) = make_pair((G).v[o],(G).w[o]), o != 0; o = (G).next[o])
};
// Program 2: BFS distance (number of relation edges) between persons a and b.
// Reads n, then a b, then m, then m lines "x y"; prints the distance or -1.
// Every relation is inserted in both directions and edge slot 0 is reserved as
// the list terminator, so the edge pool needs 2*m+1 slots. M = 2*N covers any
// m < N (the old M = N overflowed the edge arrays as soon as m > 51).
const int N = 103, M = 2*N;
int n, m, a, b, dist[N];
bool visited[N];
graph<char,N,M> g;
int main() {
    cin.tie(0)->sync_with_stdio(0);
    cin >> n >> a >> b >> m;
    rep(i,1,m) {
        int x, y;
        cin >> x >> y;
        g.add_edge(x,y);
        g.add_edge(y,x);
    }
    queue<int> q;
    fill(dist,dist+N,1e9), q.push(a), dist[a] = 0;
    // Mark the source as well; otherwise it is rediscovered from its first
    // neighbour and dist[a] is overwritten with 2.
    visited[a] = true;
    while (not empty(q)) {
        int s = q.front(); q.pop();
        for_adj(g,s,u,w) {
            if (visited[u]) continue;
            visited[u] = true;
            dist[u] = dist[s]+1;
            q.push(u);
        }
    }
    if (dist[b] < 1e9) cout << dist[b];
    else cout << -1;
}