본문 바로가기
AI/AI 부트캠프

[AI 부트캠프] DAY 83 - 딥러닝 프로젝트 8

by HOHHOH 2023. 11. 17.

[오늘의 일지]

딥러닝 프로젝트 - 최종 모델 선정

[상세 내용]

최종 모델 선정

추가 기능 

- 이제 프로젝트 기간이 얼마 남지 않았습니다. 그래서 저희 조는 마지막으로 모델을 선정하고 어느 정도의 결과를 도출해 내야 했으며 몇 가지 기능을 수정해 가며 개선된 결과도 찾아야만 했습니다. 그래서 기존 일지에서 올렸던 LSTM 코드를 기반으로 몇 가지 기능을 추가했습니다. 첫 번째는 피처 엔지니어링을 통해서 주요 피처들의 60초 후의 값들을 같은 row에 추가시켜 주는 것을 shift 함수를 이용하여 구현했습니다. 캐글에서 참고했던 코드들은 shift를 60초만 한 게 아니라 10초, 20초, 30초,... , 60초까지 더 생성하는 것이었기 때문에 그냥 60초만 생성해 주는 코드로 변경해서 사용했습니다. 그리고 LSTM 모델에 Attention이라는 기능을 추가해서 피처 별 가중치를 설정해 줘서 모델이 예측하는데 성능을 올릴 수 있게 도와줬습니다. 코드로 보여드리겠습니다.

 

- Feature Engineering

# shift를 통해 t+60 관련 피처 추가
def add_historic_features(df, cols, shifts=6, add_first=True):
    for col in cols:
        grouped_vals = df[["stock_id", "date_id", col]].groupby(["stock_id", "date_id"])
        fill_value = df[col].mean()

        # Shifted column with 6 shifts
        df[col+"_shift_"+ str(shifts)] = grouped_vals[col].transform(lambda x: x.shift(shifts)).fillna(fill_value)

        if add_first:
            df = df.merge(grouped_vals.first().reset_index(), on=["date_id", "stock_id"], suffixes=["", "_first"])
    return df

def fillmean(df, cols):
    for col in cols:
        mean_val = df[col].mean()
        df[col] = df[col].fillna(mean_val)
    return df
    

# 관련성이 있을거 같은 피처를 조합해서 추가
def add_info_columns(raw_df):
    df = raw_df.copy()

    df[["reference_price", "far_price","near_price","bid_price","ask_price","wap"]] = df[["reference_price", "far_price","near_price","bid_price","ask_price","wap"]].fillna(1.0)
    df = fillmean(df, ["imbalance_size", "matched_size"])

    df['imbalance_ratio'] = df['imbalance_size'] / (df['matched_size'] + 1.0e-8)
    df["imbalance"] = df["imbalance_size"] * df["imbalance_buy_sell_flag"]

    df['ordersize_imbalance'] = (df['bid_size']-df['ask_size']) / ((df['bid_size']+df['ask_size'])+1.0e-8)
    df['matching_imbalance'] = (df['imbalance_size']-df['matched_size']) / ((df['imbalance_size']+df['matched_size'])+1.0e-8)

    df = add_historic_features(df, ["imbalance","imbalance_ratio","reference_price","wap","matched_size","far_price","near_price"], shifts=6, add_first=True)

    return df

 

- Attention 기능 추가

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn_weights = nn.Linear(hidden_size * 2, 1)

    def forward(self, lstm_output, last_hidden_state):
        combined = torch.cat((lstm_output, last_hidden_state.unsqueeze(1).repeat(1, lstm_output.size(1), 1)), dim=2)
        attn_weights = F.softmax(self.attn_weights(combined), dim=1)
        attn_output = torch.sum(attn_weights * lstm_output, dim=1)
        return attn_output

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.attention = Attention(hidden_size)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        lstm_out, (h_n, c_n) = self.lstm(x, (h0, c0))
        attn_output = self.attention(lstm_out, h_n[-1, :, :])
        out = self.fc(attn_output)

        return out

# Hyperparameters
input_size = 33  # Number of features
hidden_size = 128
num_layers = 4
output_size = 1
dropout_rate = 0.25

 

참고 (현재 상황: 데이터 업그레이드로 혼란 상황)

- 현재 캐글에서 test 데이터의 API를 업그레이들 한 상황이기 때문에 어제 일지에서 언급했던 'submission scoring error'라던지 'notebook timeout error'가 발생하고 있는 것 같습니다. 또한 기존에 그냥 올렸던 코드들도 오류로 인해 더 이상 제출이 안되고 있는 상황이라 대회의 디스커션을 보면 많은 사람들이 당황해하고 있는 거처럼 보였습니다. 뿐만 아니라 데이터 이슈로 인해 리더보드 상에서 순위 업그레이드도 잘 안되고 있는 거처럼 보입니다. 시간이 차츰 지나면 누군가 상황을 해결하겠지만 초보자입장인 저에게는 처음 겪는 상황이라 당황스럽습니다.

 

[마무리]

 오늘은 최종 모델의 형태를 구축하였습니다. 그런데 캐글의 데이터 상의 업그레이드 이슈 때문에 제출이 불가해진 상황에서 결과를 도출해서 인사이트를 얻어야 프로젝트가 어느 정도 잘 마무리될 거 같은데 안되고 있어서 매우 당황스러운 상황입니다. 주말까지만 해결해서 결과를 얻을 수 있으면 좋을 거 같은 바람이 있습니다.

반응형

댓글