[오늘의 일지]
딥러닝 프로젝트 - 최종 모델 선정
[상세 내용]
최종 모델 선정
추가 기능
- 이제 프로젝트 기간이 얼마 남지 않았습니다. 그래서 저희 조는 마지막으로 모델을 선정하고 어느 정도의 결과를 도출해 내야 했으며 몇 가지 기능을 수정해 가며 개선된 결과도 찾아야만 했습니다. 그래서 기존 일지에서 올렸던 LSTM 코드를 기반으로 몇 가지 기능을 추가했습니다. 첫 번째는 피처 엔지니어링을 통해서 주요 피처들의 60초 후의 값들을 같은 row에 추가시켜 주는 것을 shift 함수를 이용하여 구현했습니다. 캐글에서 참고했던 코드들은 shift를 60초만 한 게 아니라 10초, 20초, 30초,... , 60초까지 더 생성하는 것이었기 때문에 그냥 60초만 생성해 주는 코드로 변경해서 사용했습니다. 그리고 LSTM 모델에 Attention이라는 기능을 추가해서 피처 별 가중치를 설정해 줘서 모델이 예측하는데 성능을 올릴 수 있게 도와줬습니다. 코드로 보여드리겠습니다.
- Feature Engineering
# shift를 통해 t+60 관련 피처 추가
def add_historic_features(df, cols, shifts=6, add_first=True):
for col in cols:
grouped_vals = df[["stock_id", "date_id", col]].groupby(["stock_id", "date_id"])
fill_value = df[col].mean()
# Shifted column with 6 shifts
df[col+"_shift_"+ str(shifts)] = grouped_vals[col].transform(lambda x: x.shift(shifts)).fillna(fill_value)
if add_first:
df = df.merge(grouped_vals.first().reset_index(), on=["date_id", "stock_id"], suffixes=["", "_first"])
return df
def fillmean(df, cols):
for col in cols:
mean_val = df[col].mean()
df[col] = df[col].fillna(mean_val)
return df
# 관련성이 있을거 같은 피처를 조합해서 추가
def add_info_columns(raw_df):
df = raw_df.copy()
df[["reference_price", "far_price","near_price","bid_price","ask_price","wap"]] = df[["reference_price", "far_price","near_price","bid_price","ask_price","wap"]].fillna(1.0)
df = fillmean(df, ["imbalance_size", "matched_size"])
df['imbalance_ratio'] = df['imbalance_size'] / (df['matched_size'] + 1.0e-8)
df["imbalance"] = df["imbalance_size"] * df["imbalance_buy_sell_flag"]
df['ordersize_imbalance'] = (df['bid_size']-df['ask_size']) / ((df['bid_size']+df['ask_size'])+1.0e-8)
df['matching_imbalance'] = (df['imbalance_size']-df['matched_size']) / ((df['imbalance_size']+df['matched_size'])+1.0e-8)
df = add_historic_features(df, ["imbalance","imbalance_ratio","reference_price","wap","matched_size","far_price","near_price"], shifts=6, add_first=True)
return df
- Attention 기능 추가
class Attention(nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
self.hidden_size = hidden_size
self.attn_weights = nn.Linear(hidden_size * 2, 1)
def forward(self, lstm_output, last_hidden_state):
combined = torch.cat((lstm_output, last_hidden_state.unsqueeze(1).repeat(1, lstm_output.size(1), 1)), dim=2)
attn_weights = F.softmax(self.attn_weights(combined), dim=1)
attn_output = torch.sum(attn_weights * lstm_output, dim=1)
return attn_output
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.attention = Attention(hidden_size)
self.dropout = nn.Dropout(p=dropout_rate)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
lstm_out, (h_n, c_n) = self.lstm(x, (h0, c0))
attn_output = self.attention(lstm_out, h_n[-1, :, :])
out = self.fc(attn_output)
return out
# Hyperparameters
input_size = 33 # Number of features
hidden_size = 128
num_layers = 4
output_size = 1
dropout_rate = 0.25
참고 (현재 상황: 데이터 업그레이드로 혼란 상황)
- 현재 캐글에서 test 데이터의 API를 업그레이들 한 상황이기 때문에 어제 일지에서 언급했던 'submission scoring error'라던지 'notebook timeout error'가 발생하고 있는 것 같습니다. 또한 기존에 그냥 올렸던 코드들도 오류로 인해 더 이상 제출이 안되고 있는 상황이라 대회의 디스커션을 보면 많은 사람들이 당황해하고 있는 거처럼 보였습니다. 뿐만 아니라 데이터 이슈로 인해 리더보드 상에서 순위 업그레이드도 잘 안되고 있는 거처럼 보입니다. 시간이 차츰 지나면 누군가 상황을 해결하겠지만 초보자입장인 저에게는 처음 겪는 상황이라 당황스럽습니다.
[마무리]
오늘은 최종 모델의 형태를 구축하였습니다. 그런데 캐글의 데이터 상의 업그레이드 이슈 때문에 제출이 불가해진 상황에서 결과를 도출해서 인사이트를 얻어야 프로젝트가 어느 정도 잘 마무리될 거 같은데 안되고 있어서 매우 당황스러운 상황입니다. 주말까지만 해결해서 결과를 얻을 수 있으면 좋을 거 같은 바람이 있습니다.
'AI > AI 부트캠프' 카테고리의 다른 글
[AI 부트캠프] DAY 85 - 딥러닝 프로젝트 10 (0) | 2023.11.21 |
---|---|
[AI 부트캠프] DAY 84 - 딥러닝 프로젝트 9 (0) | 2023.11.18 |
[AI 부트캠프] DAY 82 - 딥러닝 프로젝트 7 (0) | 2023.11.16 |
[AI 부트캠프] DAY 81 - 딥러닝 프로젝트 6 (1) | 2023.11.15 |
[AI 부트캠프] DAY 80 - 딥러닝 프로젝트 5 (0) | 2023.11.14 |
댓글