본 포스팅은 패스트캠퍼스 환급 챌린지 참여를 위해 작성하였습니다.

https://abit.ly/lisbva

 

공부 시작

강의 종료

강의장

학습 인증샷

학습후기

 

ch01-10 코드실습 2 DGL 라이브러리, 3 GraphSAGE

 

DGL이란?

복잡한 그래프 연산을 쉽게 짜게 도와주는 PyTorch 기반 GNN 프레임워크

 

복잡한 그래프 연산을 한 줄로 처리 가능

샘플링, 배치 처리 등 연산 자동화

다양한 GNN 모델 템플릿 제공

 

 

PyTorch 기본 구현

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn

class DGLGraphSAGE(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim * 2, out_dim)  # 자기 정보 + 이웃 정보 concat 후 projection

    def forward(self, g, h):
        # 1. 노드 임베딩을 그래프에 등록
        g.ndata['h'] = h

        # 2. 메시지 전달 및 aggregation 수행
        # fn.copy_u: 각 노드의 'h' 값을 이웃에게 메시지로 복사
        # fn.mean: 받은 메시지들을 평균내어 'h_neigh'라는 이름으로 저장
        g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))

        # 3. 자기 자신의 임베딩과 이웃 평균 임베딩을 concat
        h_self = g.ndata['h']
        h_neigh = g.ndata['h_neigh']
        h_concat = torch.cat([h_self, h_neigh], dim=1)

        # 4. 선형변환 후 활성화 함수 적용
        return F.relu(self.linear(h_concat))

DGL 버전 구현

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

class DGLGraphSAGE(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # 이웃 임베딩 + 자기 임베딩을 concat한 후 projection할 선형층
        self.linear = nn.Linear(in_dim * 2, out_dim)

    def forward(self, g, h):
        # [PyTorch였다면]
        # - 각 노드마다 이웃을 수동으로 순회하면서,
        # - 메시지를 보내고, 평균을 내고, concat을 계산해야 함
        # - 반복문, 평균, 조건 분기 등 수작업이 들어감

        # [DGL에선] 아래 3줄로 끝남:

        # 1. 노드 임베딩을 그래프에 등록 (노드 피처로 저장)
        g.ndata['h'] = h

        # 2. 메시지 전달 및 aggregation을 한 줄로 처리
        #    - fn.copy_u: 각 노드의 'h'를 이웃에게 메시지로 보냄
        #    - fn.mean: 받은 메시지들을 평균내서 'h_neigh'로 저장
        g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))

        # 3. 자기 임베딩과 이웃 임베딩을 concat 후 projection
        h_self = g.ndata['h']
        h_neigh = g.ndata['h_neigh']
        h_concat = torch.cat([h_self, h_neigh], dim=1)
        return F.relu(self.linear(h_concat))

        # 결과적으로:
        # [PyTorch 로직 수십 줄] → [DGL로 3~4줄]  
        # 반복문 없이 전체 그래프에 대해 병렬 연산 처리됨