들어가며
본 포스팅에서는 서로소 집합과 구현 방식에 대해 소개합니다.
📌 서로소 집합이란?
서로소 집합(Disjoint Sets)란 공통 원소가 없는 두 집합을 의미합니다. 예를 들어 {1, 2}와 {3, 4}는 서로소 관계이지만, {1, 2}와 {2, 3}은 서로소 관계가 아닙니다.
서로소 집합 자료구조는 union과 find라는 2개의 연산이 이루어집니다.
union(합집합)이란, 하나의 집합으로 합치는 연산을 의미하며, find(찾기) 연산은 특정 원소가 어느 집합에 속하였는지를 찾아내는 연산입니다.
1️⃣ union(합집합) 연산을 통해, 서로 연결된 두 개의 노드를 확인합니다.
- 1 ) 노드 A와 노드 B의 루트 노드인 A'와 B'를 찾습니다.
- 2 ) 루트 노드 A'를 루트 노드 B'의 부모 노드로 설정합니다. (B'가 A'를 가리키게 됩니다.)
2️⃣ 모든 union(합집합) 연산을 처리하여 위의 과정이 반복됩니다.
1. 노드의 개수만큼 부모 테이블을 초기화합니다.
노드 번호 | 1 | 2 | 3 | 4 | 5 | 6 |
부모 | 1 | 2 | 3 | 4 | 5 | 6 |
2. union 1, 4의 경우, 노드 1과 4의 루트 노드를 찾습니다. 현재 루트 노드는 1과 4이므로 더 큰 번호인 4의 부모를 1로 설정합니다.
노드 번호 | 1 | 2 | 3 | 4 | 5 | 6 |
부모 | 1 | 2 | 3 | 1 | 5 | 6 |
3. union 2, 3의 경우, 노드 2와 3이 루트 노드를 찾습니다. 현재 루트 노드는 2와 3이므로 더 큰 번호인 3의 부모를 2로 설정합니다.
노드 번호 | 1 | 2 | 3 | 4 | 5 | 6 |
부모 | 1 | 2 | 2 | 1 | 5 | 6 |
4. union 2, 4의 경우, 노드 2와 4의 루트 노드를 찾습니다. 현재 루트 노드는 2와 1이므로 더 큰 번호인 2의 부모를 1로 설정합니다.
노드 번호 | 1 | 2 | 3 | 4 | 5 | 6 |
부모 | 1 | 1 | 2 | 1 | 5 | 6 |
5. union 5, 6의 경우, 노드 5와 6의 루트 노드를 찾습니다. 현재 루트 노드는 5와 6이므로 더 큰 번호인 6의 부모를 5로 설정합니다.
노드 번호 | 1 | 2 | 3 | 4 | 5 | 6 |
부모 | 1 | 1 | 2 | 1 | 5 | 5 |
이렇게 union 연산을 모두 수행하고 난 모습은 위와 같습니다.
📌 기본적인 서로소 집합 알고리즘
def find_parent(parent, x):
if parent[x] != x:
return find_parent(parent, parent[x])
else:
return x
루트 노드를 찾기 위해 재귀적으로 함수를 호출하여 루트 노드를 찾아냅니다.
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
노드 a와 b의 루트 노드를 찾아, 더 큰 번호의 루트 노드를 변경해줍니다.
✅ [ CODE ]
import sys
input = sys.stdin.readline
def find_parent(parent, x):
if parent[x] != x:
return find_parent(parent, parent[x])
else:
return x
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
v, e = map(int, input().split())
parent = [i for i in range(v + 1)]
print(parent)
for _ in range(e):
a, b = map(int, input().split())
union_parent(parent, a, b)
print("각 원소가 속한 집합 : ", end=" ")
for i in range(1, v + 1):
print(find_parent(parent, i), end = " ")
print()
print("부모 테이블 : ", end=" ")
for i in range(1, v + 1):
print(parent[i], end=" ")
☑️ 시간 복잡도
루트 노드를 구하기 위해선 부모 노드를 확인하며 거슬러 올라가야만 합니다. 예를 들어 노드 3의 부모 노드는 2이며, 노드의 2의 부모 노드는 1이며 최종적으로 루트 노드는 1이 됩니다. 따라서, find 연산을 하려면, 모든 노드를 확인하기 때문에 최악의 경우, 시간 복잡도는 $O(V)$가 됩니다.
노드의 개수가 V이며, union의 연산 횟수를 M이라고 할 때, 최종적으로 시간 복잡도는 $O(VM)$이 됩니다.
이를 개선하기 위해 루트 노드를 빠르게 알아내어 더 효율적인 방식으로 시간 복잡도를 개선하는 방식으로는 경로 압축 기법이 있습니다.
📌 경로 압축을 이용한 서로소 집합
노드 6의 루트 노드를 찾기 위해선, 6 → 5 → 4 → 3 → 2 → 1로 부모 노드를 거슬러 올라가야 합니다. 따라서 최대 노드의 개수만큼의 연산이 필요하게 됩니다.
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
재귀적으로 호출하여 루트 노드를 바로 부모 노드로 변경해주어, 루트 노드에 빠르게 접근할 수 있도록 합니다.
✅ [ CODE ]
import sys
input = sys.stdin.readline
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
v, e = map(int, input().split())
parent = [i for i in range(v + 1)]
for _ in range(e):
a, b = map(int, input().split())
union_parent(parent, a, b)
print("각 원소가 속한 집합 : ", end=" ")
for i in range(1, v + 1):
print(find_parent(parent, i), end=" ")
print()
print("부모 테이블 : ", end=" ")
for i in range(1, v + 1):
print(parent[i], end=" ")
📌 서로소 집합을 활용한 사이클 판별
무방향 그래프 내에서의 사이클을 판별할 때 서로소 집합을 사용해볼 수 있습니다.
1️⃣ 각 간선을 확인하며 두 노드의 루트 노드를 확인합니다.
- 1 ) 루트 노드가 서로 다르다면, 두 노드에 대하여 union 연산을 수행합니다.
- 2 ) 루트 노드가 서로 같다면, 사이클이 존재한다는 것을 의미합니다.
2️⃣ 그래프에 포함되어 있는 모든 가선에 대해 위의 과정을 반복합니다.
1. 노드의 개수만큼 부모 테이블을 초기화합니다.
노드 번호 | 1 | 2 | 3 |
부모 | 1 | 2 | 3 |
2. 간선 (1, 2)를 확인하며 각 루트 노드는 1과 2입니다. 더 큰 번호인 2의 부모 노드를 1로 변경합니다.
노드 번호 | 1 | 2 | 3 |
부모 | 1 | 1 | 3 |
3. 간선 (1, 3)을 확인하면 각 루트 노드는 1과 3입니다. 더 큰 번호인 3의 부모 노드를 1로 변경합니다.
노드 번호 | 1 | 2 | 3 |
부모 | 1 | 1 | 1 |
4, 간선 (2, 3)을 확인하면 각 루트 노드는 1과 1입니다. 동일한 노드 1을 루트 노드로 가지고 있으므로, 사이클이 존재한다는 것을 의미합니다.
✅ [ CODE ]
"""
< 서로소 집합 - 사이클 판별 >
"""
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
v, e = map(int, input().split())
parent = [0] * (v + 1)
for i in range(1, v + 1):
parent[i] = i
cycle = False
for i in range(e):
a, b = map(int, input().split())
if find_parent(parent, a) == find_parent(parent, b):
cycle = True
break
else:
union_parent(parent, a, b)
if cycle:
print("사이클 발생")
else:
print("사이클 노 발생")
포스팅 내용에 오류 및 문제가 있는 경우, 댓글로 남겨주시면 감사하겠습니다.
출처 ) 이것이 코딩 테스트다 with Python