Hanbit the Developer

[C] Prim-Jarnik Algorithm using Priorty Queue 본문

Algorithm

[C] Prim-Jarnik Algorithm using Priorty Queue

hanbikan 2021. 11. 26. 19:36
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
#pragma warning(disable:4996)

typedef struct Node {
	int value;
	struct Node* next;
} Node;

typedef struct Vertex {
	Node* header;
} Vertex;

typedef struct Edge {
	int v1;
	int v2;
	int w;
} Edge;

typedef struct Metadata {
	Vertex** vertices;
	int verticesCount;
	Edge** edges;
	int edgesCount;
} Metadata;

typedef struct Pair {
	int a, b;
} Pair;

typedef struct PriortyQueue {
	Pair* data;
	int count;
} PriortyQueue;

void swapItems(PriortyQueue* pqueue, int* locator, int index1, int index2);
void upHeap(PriortyQueue* pqueue, int* locator, int index);
void downHeap(PriortyQueue* pqueue, int* locator, int index);
void setElement(PriortyQueue* pqueue, int* locator, int index, int n);
int getOppositeIndex(Edge* edge, int vertexIndex);
Pair pop(PriortyQueue* pqueue, int* locator);
void printMST(Metadata* mt, int start);
Edge* getEdge(Metadata* mt, int index1, int index2, int weight);
Node* getNode(int n);
void addEdgeToVertex(Metadata* mt, int edgeIndex, int vertexIndex);
void freeGraph(Metadata mt, int n, int m);

Node* getNode(int n) {
	Node* res = malloc(sizeof(Node));
	res->value = n;
	res->next = NULL;

	return res;
}

Edge* getEdge(Metadata* mt, int index1, int index2, int weight) {
	Edge* res = malloc(sizeof(Edge));
	res->v1 = index1;
	res->v2 = index2;
	res->w = weight;

	return res;
}

void addEdgeToVertex(Metadata* mt, int edgeIndex, int vertexIndex) {
	Node* next = mt->vertices[vertexIndex]->header->next;
	Node* newNode = getNode(edgeIndex);	// Node holds the index of edges
	mt->vertices[vertexIndex]->header->next = newNode;
	newNode->next = next;
}


void printMST(Metadata* mt, int start) {
	int i;
	int vCount = mt->verticesCount;

	// Initialize dists
	int* dists = malloc(sizeof(int) * (vCount + 1));
	for (i = 1; i < vCount + 1; i++) dists[i] = INT_MAX;
	dists[start] = 0;

	// Initialize isVisited
	int* isVisited = malloc(sizeof(int) * (vCount + 1));
	for (i = 1; i < vCount + 1; i++) isVisited[i] = 0;

	// Initialize priorty queue: (dist, vertexIndex) 
	PriortyQueue *pqueue = (PriortyQueue *)malloc(sizeof(PriortyQueue));
	pqueue->data = (Pair*)malloc(sizeof(Pair) * (vCount + 1));
	for (i = 1; i < vCount + 1; i++) {
		pqueue->data[i].a = INT_MAX;
		pqueue->data[i].b = i;
	}
	pqueue->data[start].a = 0;
	pqueue->count = 1;

	int* locator = malloc(sizeof(int) * (vCount + 1));
	for (i = 1; i < vCount + 1; i++) locator[i] = i;

	// Prim-Jarnik
	while (pqueue->count >= 1) {
		Pair top = pop(pqueue, locator);
		isVisited[top.b] = 1;
		printf(" %d", top.b);

		// Visit adjacent nodes
		Node* cur = mt->vertices[top.b]->header->next;
		while (cur) {
			int opposite = getOppositeIndex(mt->edges[cur->value], top.b);

			if (mt->edges[cur->value]->w < dists[opposite] && isVisited[opposite] == 0) {
				dists[opposite] = mt->edges[cur->value]->w;

				if (locator[opposite] > pqueue->count) {
					pqueue->count++;
					swapItems(pqueue, locator, locator[opposite], pqueue->count);
				}
				setElement(pqueue, locator, locator[opposite], dists[opposite]);
			}

			cur = cur->next;
		}
	}
	printf("\n");

	// Print sum of dists
	int res = 0;
	for (i = 2; i < vCount + 1; i++) res += dists[i];
	printf("%d", res);

	// Free			
	free(dists);
	free(isVisited);
	free(pqueue->data);
	free(pqueue);
	free(locator);
}

Pair pop(PriortyQueue *pqueue, int* locator) {
	Pair res = pqueue->data[1];
	pqueue->data[1].a = INT_MAX;
	swapItems(pqueue, locator, 1, pqueue->count);
	pqueue->count -= 1;
	downHeap(pqueue, locator, 1);

	return res;
}

void downHeap(PriortyQueue* pqueue, int* locator, int index) {
	if (index * 2 > pqueue->count) return;

	int child = index * 2;
	if (child + 1 <= pqueue->count && pqueue->data[child].a > pqueue->data[child + 1].a) child += 1;

	if (pqueue->data[index].a <= pqueue->data[child].a) return;

	swapItems(pqueue, locator, index, child);
	downHeap(pqueue, locator, child);
}

void swapItems(PriortyQueue* pqueue, int* locator, int index1, int index2) {
	locator[pqueue->data[index1].b] = index2;
	locator[pqueue->data[index2].b] = index1;

	Pair tmp = pqueue->data[index1];
	pqueue->data[index1] = pqueue->data[index2];
	pqueue->data[index2] = tmp;
}

int getOppositeIndex(Edge* edge, int vertexIndex) {
	int res = edge->v1;
	if (edge->v1 == vertexIndex) res = edge->v2;

	return res;
}


void setElement(PriortyQueue* pqueue, int* locator, int index, int n) {
	pqueue->data[index].a = n;
	if (index > 1 && pqueue->data[index / 2].a > pqueue->data[index].a) upHeap(pqueue, locator, index);
	else downHeap(pqueue, locator, index);
}


void upHeap(PriortyQueue* pqueue, int* locator, int index) {
	if (index <= 1) return;

	int parent = index / 2;
	if (pqueue->data[parent].a <= pqueue->data[index].a) return;

	swapItems(pqueue, locator, index, parent);
	upHeap(pqueue, locator, parent);
}


void freeGraph(Metadata mt, int n, int m) {
	int i;

	for (i = 1; i < n + 1; i++) {
		Node* cur = mt.vertices[i]->header;
		while (cur) {
			Node* next = cur->next;
			free(cur);
			cur = next;
		}
		free(mt.vertices[i]);
	}
	free(mt.vertices);

	for (i = 0; i < m; i++) free(mt.edges[i]);
	free(mt.edges);
}

int main() {
	int i;

	int n, m;
	scanf("%d %d", &n, &m);

	Metadata mt;
	mt.verticesCount = n;
	mt.vertices = malloc(sizeof(Vertex*) * (n + 1));
	for (i = 1; i < n + 1; i++) {
		mt.vertices[i] = malloc(sizeof(Vertex));
		mt.vertices[i]->header = getNode(NULL);
	}
	mt.edges = malloc(sizeof(Edge*) * m);
	mt.edgesCount = m;

	for (i = 0; i < m; i++) {
		int a, b, w;
		scanf("%d %d %d", &a, &b, &w);
		mt.edges[i] = getEdge(&mt, a, b, w);
		addEdgeToVertex(&mt, i, a);
		addEdgeToVertex(&mt, i, b);
	}

	// Solution
	printMST(&mt, 1);

	// Free
	freeGraph(mt, n, m);

	return 0;
}

 

1. 특정 노드 방문 처리(isVisited), 우선순위큐에서 빼냄

2. 해당 노드의 인접 노드들을 돌면서, '거리가 저장된 dists값보다 작으면서, 그곳에 방문하지 않았을 경우'에, dist값 갱신, 인접노드의 우선순위큐의 값 갱신을 해줌

 

위 루틴을 우선순위큐가 빌 때까지 반복해준다.

아래 케이스를 순차적으로 그려가면서 분석한 것이 큰 도움이 되었다.

 

7 10
1 2 40
1 3 80
2 4 10
2 3 90
3 4 60
3 6 30
3 7 50
4 5 100
5 6 70
6 7 20

 

1->2->4->3->6->7->5 순서로 방문하게 되며, 230의 비용이 든다.

 

isVisited와, dists 및 pqueue는 서순이 다르다는 점에 유의해야한다. isVisited에서 경로가 완전히 확정된 현재 노드의 값을 1(true)로 지정해준 뒤에, dists 및 pqueue를 현재 노드와 인접한 노드들을 기준으로 업데이트하는 것이다. 이후에 이 정보를 기준으로 min값을 가져와서 현재 노드를 정해주게 되는 것이고, 이런 식으로 반복된다.