training and plotting updates
This commit is contained in:
74
plot_results.py
Normal file
74
plot_results.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from config_ import ModelConfig
|
||||
from train_and_eval import calculate_f1_score
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
data_dir = "datasets/consensus_dataset"
|
||||
else:
|
||||
data_dir = "datasets/" + sys.argv[1]
|
||||
|
||||
datapoints = {}
|
||||
THRESHOLD = 0.2
|
||||
|
||||
for folder in tqdm(os.listdir(data_dir)):
|
||||
num_agents = int(folder.split("_")[1]) # Extract num agents
|
||||
|
||||
folder_path = os.path.join(data_dir, folder)
|
||||
|
||||
# Load model config from summary json
|
||||
with open(os.path.join(folder_path, "results/NoiseType.NONE", "summary_results.json"), "r") as f:
|
||||
summary_results = json.load(f)
|
||||
|
||||
|
||||
for i, graph in enumerate(os.listdir(folder_path)):
|
||||
|
||||
# train_summary_results
|
||||
summ_results = summary_results[i-1]
|
||||
|
||||
if graph == "results": # ignore the result folder
|
||||
continue
|
||||
|
||||
graph_path = os.path.join(folder_path, graph)
|
||||
|
||||
# Load run data
|
||||
with open(os.path.join(folder_path, graph), "r") as f:
|
||||
run_data = json.load(f)
|
||||
|
||||
true_graph = np.array(run_data["adjacency_matrix"])
|
||||
|
||||
learned_graph = np.array(summ_results["raw_attention"])
|
||||
|
||||
predicted_graph = (learned_graph > THRESHOLD).astype(int)
|
||||
|
||||
true_flat = true_graph.flatten()
|
||||
pred_flat = predicted_graph.flatten()
|
||||
|
||||
calc_f1_score = f1_score(true_flat, pred_flat)
|
||||
|
||||
|
||||
datapoints[num_agents] = datapoints.get(num_agents, [])
|
||||
datapoints[num_agents].append(calc_f1_score)
|
||||
|
||||
|
||||
for key in datapoints.keys():
|
||||
datapoints[key] = sum(datapoints[key])/len(datapoints[key])
|
||||
|
||||
|
||||
x = []
|
||||
y = []
|
||||
|
||||
for item in datapoints.items():
|
||||
x.append(item[0])
|
||||
y.append(item[1])
|
||||
|
||||
plt.plot(x, y)
|
||||
plt.show()
|
||||
|
||||
|
Reference in New Issue
Block a user