ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [Python] union find (disjoint-set) 알고리즘
    언어/파이썬 & 장고 2019. 10. 21. 00:16

    union find (disjoint-set) 이란?

    서로 중복되지 않는 부분 집합들로 나눠진 원소들에 대한 정보를 저장하고 조작하는 자료 구조입니다. 간단하게 다수의 노드들 중에 연결된 노드를 찾거나 노드들을 합칠때 사용하는 알고리즘입니다. 아래 예시 그림을 보면 어떤 개념인지 바로 이해가 됩니다.

    원소(노드)가 1~10까지 있다고 할 때, 해당 원소들은 아래와 같은 리스트의 구조를 갖습니다. 첫 번째 행은 원소의 번호고 두 번째 행은 원소의 관계로 보면 됩니다. 

    12345678910
    12345678910


    이러한 원소들은 아래와 같은 연결 구조를 갖고 있다고 가정합니다.



    위 그림은 원소 1,2,3,4 / 5,6,7,8 / 9,10 과 같이 3개의 그룹으로 묶여 있습니다. 

    이를 리스트로 표현하면 아래와 같습니다.

    12345678910
    1123556799


    위와 같이 두 번째 행인 연결정보를 갱신할 때는 작은 값을 기준으로 갱신합니다. 예를 들어, 원소 1과 2가 연결되었다고 할 때는 원소 2의 정보를 1로 바꿉니다. 1,2,3,4 모두 연결되어 있어 같은 집단에 속하지만 위의 테이블에서는 원소 3과 4의 연결 정보는 1이 아닙니다. 이를 처리하기 위해선 순차적으로 진행이 되어야 합니다.

    1. 원소 3과 2는 연결 되어 있음
    2. 원소 3의 연결 정보를 2로 변경
    3. 여기서 원소 2와 1은 연결되어 있음
    4. 삼단논법으로 인해 원소 3의 연결정보 2는 원소 1과 연결되어 있음을 알 수 있음
    5. 원소 3의 연결정보를 1로 변경


    위와 같은 순서를 진행하면 결국 아래와 같은 테이블이 결과로 나오게 됩니다.

    12345678910
    1111555599


    구현

    union find(disjoint-set)의 핵심은 아래 3가지 입니다.

    • 초기화 : N 개의 원소가 각각의 집합에 포함되어 있도록 초기화

    • Union 연산 : 두 원소 a, b 가 주어질 때, 이들이 속한 두 집합을 하나로 합침

    • Find 연산 : 어떤 원소 a 가 주어질 때, 이 원소가 속한 집합을 반환


    이 알고리즘을 구현하는 방식은 배열을 사용하는 것과 트리구조를 사용하는 방식 두 가지가 있습니다. 보통은 효율성 때문에 배열로 사용하지 않고 트리구조를 사용합니다.

    1. 배열을 사용하는 방식

    배열을 사용하여 disjoint-set을 구현하는 방법은 아래와 같습니다.

    1. 먼저 원소의 크기만큼 배열을 초기화(생성)해줍니다.
    2. 그 다음 두 원소를 합치기 위해 배열의 모든 원소를 순회하면서 하나의 번호를 나머지 하나로 교체합니다. (두 원소들을 합칠 때, 해당 원소의 연결정보를 찾는 연산은 한 번 만에 알 수 있습니다.)


    위에서도 설명했듯이 합치는 연산(union)에서 배열의 모든 원소를 순회하게 되기 때문에 시간복잡도가 O(N)입니다. 이를 해결하기 위해 보통 트리구조를 많이 사용합니다.

    예제

    disjoint-set을 검색하면 보통 재귀함수를 사용한 예제코드를 많이 볼 수 있는데, 해당 코드의 문제점은 연결 정보 순서가 바뀌면 동작에 버그가 있다는 점입니다.

    • disjoint(1,2), disjoiint(2,3)은 정상으로 동작하지만 이 순서를 바꾸면 오류가 발생

    따라서 아래 예제는 재귀 함수를 지양하고 반복문을 사용한 코드입니다.

    class DisjointSet:
        def __init__(self, n):
            self.data = list(range(n))
            self.size = n
    
        def find(self, index):
            return self.data[index]
    
    
    	def union(self, x, y):
        	x, y = self.find(x), self.find(y)
    
    	    if x == y:
        	    return
    
    	    for i in range(self.size):
        	    if self.find(i) == y:
            	    self.data[i] = x
    
    
        @property
        def length(self):
            return len(set(self.data))
    
    
    
    
    disjoint = DisjointSet(10)
    
    disjoint.union(0, 1)
    disjoint.union(1, 2)
    disjoint.union(2, 3)
    disjoint.union(4, 5)
    disjoint.union(5, 6)
    disjoint.union(6, 7)
    disjoint.union(8, 9)
    
    print(disjoint.data)
    print(disjoint.length)
    
    
    # [0, 0, 0, 0, 4, 4, 4, 4, 8, 8]
    # 3

    2. 트리구조를 사용하는 방식

    트리 구조에서도 union-by-size, union-by-height, path comprehension 과 같이 3가지 방식이 있습니다.

    union-by-size, union-by-height

    임의의 두 집합을 합칠 때는 원소의 수가 적은 집합을 원소가 많은 집합의 하위 트리로 합치는 것이 효율적입니다. 이러한 방식이 union-by-size입니다.  이와 마찬가지로 트리의 높이가 작은 집합을 높이가 더 큰 집합의 서브트리로 합치는 방식을 union-by-height라고 합니다.

    union-by-size의 방법은 아래와 같습니다. (union-by-height는 3번을 제외하고 동일합니다.)

    1. 주어진 원소의 개수만큼 사용하지 않을 값 (예제에서는 -1)을 생성
    2. 루트노드의 인덱스를 찾음
    3. 루트노드의 인덱스가 다르다면 리스트의 값이 더 낮은(size가 더 큰) 것을 찾아서 큰 것에 더해줌
    4. 작은건 큰것의 인덱스로 바꿔준다.


    트리의 경우, 높이는 logn이므로 시간복잡도는 O(logn)을 가집니다.

    예제

    union-by-size
    class DisjointSet:
        def __init__(self, n):
            self.data = [-1 for _ in range(n)]
            self.size = n
    
        def find(self, index):
            value = self.data[index]
            if value < 0:
                return index
    
            return self.find(value)
    
        def union(self, x, y):
            x = self.find(x)
            y = self.find(y)
    
            if x == y:
                return
    
            if self.data[x] < self.data[y]:
                self.data[x] += self.data[y]
                self.data[y] = x
            else:
                self.data[y] += self.data[x]
                self.data[x] = y
    
            self.size -= 1
    
    
    disjoint = DisjointSet(10)
    
    disjoint.union(0, 1)
    disjoint.union(1, 2)
    disjoint.union(2, 3)
    disjoint.union(4, 5)
    disjoint.union(5, 6)
    disjoint.union(6, 7)
    disjoint.union(8, 9)
    
    print(disjoint.data)
    print(disjoint.size)
    
    
    # [1, -4, 1, 1, 5, -4, 5, 5, 9, -2]
    # 3
    union-by-height
    class DisjointSet:
        def __init__(self, n):
            self.data = [-1] * n
            self.size = n
    
        def find(self, index):
            value = self.data[index]
            if value < 0:
                return index
    
            return self.find(value)
    
        def union(self, x, y):
            x = self.find(x)
            y = self.find(y)
    
            if x == y:
                return
    
            if self.data[x] < self.data[y]:
                self.data[y] = x
            elif self.data[x] > self.data[y]:
                self.data[x] = y
            else:
                self.data[x] -= 1
                self.data[y] = x
    
            self.size -= 1
    
    
    disjoint = DisjointSet(10)
    
    disjoint.union(0, 1)
    disjoint.union(1, 2)
    disjoint.union(2, 3)
    disjoint.union(4, 5)
    disjoint.union(5, 6)
    disjoint.union(6, 7)
    disjoint.union(8, 9)
    
    print(disjoint.data)
    print(disjoint.size)
    
    
    
    # -2의 경우 루트 인덱스이므로 세지 않음
    # [-2, 0, 0, 0, -2, 4, 4, 4, -2, 8]
    # 3

    path comprehension

    위에서 설명한 union-by-size나 union-by-height는 find() 연산을 수행할 때 트리의 높이만큼 올라가 루트를 찾을 수 있는데, 이러한 비효율성을 해결하고자 나온 개념입니다. path compression을 수행하면 루트를 찾는 find() 연산 비용을 낮출 수 있습니다. 이는 find()를 실행한 뒤에 다음 find()에서 효과적으로 찾을 수 있도록 트리를 재구성 하는 방식입니다.


    방식은 아래와 같습니다.

    1. find(x)를 실행
    2. x가 루트가 아니라면 임시로 저장
    3. 루트노드를 찾을 때 까지 재귀함수 반복
    4. 루트노드를 찾으면 x를 루트노드의 자식으로 표시

    예제

    class DisjointSet:
        def __init__(self, n):
            self.data = [-1 for _ in range(n)]
            self.size = n
    
        def upward(self, change_list, index):
            value = self.data[index]
            if value < 0:
                return index
    
            change_list.append(index)
            return self.upward(change_list, value)
    
        def find(self, index):
            change_list = []
            result = self.upward(change_list, index)
    
            for i in change_list:
                self.data[i] = result
    
            return result
    
        def union(self, x, y):
            x = self.find(x)
            y = self.find(y)
    
            if x == y:
                return
    
            if self.data[x] < self.data[y]:
                self.data[y] = x
            elif self.data[x] > self.data[y]:
                self.data[x] = y
            else:
                self.data[x] -= 1
                self.data[y] = x
    
            self.size -= 1
    
    
    disjoint = DisjointSet(10)
    
    disjoint.union(0, 1)
    disjoint.union(1, 2)
    disjoint.union(2, 3)
    disjoint.union(4, 5)
    disjoint.union(5, 6)
    disjoint.union(6, 7)
    disjoint.union(8, 9)
    
    print(disjoint.data)
    print(disjoint.size)
    
    
    
    
    # [-2, 0, 0, 0, -2, 4, 4, 4, -2, 8]
    # 3



    만약 직접 구현을 할 필요가 없고 (실무에서 사용하기 위해) 기능 자체만 사용하고 싶으면 PyPi에서 모듈을 직접 받아서 사용할 수 있습니다. (https://github.com/mrapacz/disjoint-set)

    댓글