Files
graph_recognition_w_attn/plot_results.py

80 lines
2.1 KiB
Python
Raw Normal View History

2025-08-04 12:44:35 -04:00
import os
import sys
import json
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from config_ import ModelConfig
from train_and_eval import calculate_f1_score
from sklearn.metrics import f1_score
# Plot the mean F1 score of thresholded attention matrices against the true
# adjacency matrices, grouped by the number of agents.
#
# Usage: python plot_results.py [dataset_name]
#   dataset_name is resolved under datasets/; without it the script reads
#   datasets/kuramoto_dataset.

# Attention weights strictly greater than this value count as predicted edges.
THRESHOLD = 0.2
# Per-experiment sub-folder holding results rather than run files.
RESULTS_DIR = "results"
# Only the noise-free results are evaluated.
NOISE_SUBDIR = os.path.join(RESULTS_DIR, "NoiseType.NONE")


def resolve_data_dir(argv):
    """Return the dataset directory selected on the command line.

    Args:
        argv: Argument vector in ``sys.argv`` form (program name first).

    Returns:
        ``"datasets/kuramoto_dataset"`` when no dataset name is given,
        otherwise ``datasets/<argv[1]>``.
    """
    if len(argv) < 2:
        return "datasets/kuramoto_dataset"
    return os.path.join("datasets", argv[1])


def graph_f1(adjacency, attention, threshold=THRESHOLD):
    """Return the F1 score of a thresholded attention matrix vs. the true graph.

    Both inputs are flattened element-wise before scoring, so they are
    presumably matrices of the same shape (TODO confirm against the writer of
    the run / summary files).

    Args:
        adjacency: True adjacency matrix (nested lists or array of 0/1).
        attention: Learned raw attention matrix.
        threshold: Values strictly above this are predicted edges.
    """
    true_flat = np.asarray(adjacency).flatten()
    pred_flat = (np.asarray(attention) > threshold).astype(int).flatten()
    return f1_score(true_flat, pred_flat)


def collect_f1_scores(data_dir, threshold=THRESHOLD):
    """Compute one F1 score per (run file, noise level) and group them by agent count.

    Expected layout (taken from the paths this function reads)::

        data_dir/<folder>/<run file>                      JSON with "adjacency_matrix"
        data_dir/<folder>/results/NoiseType.NONE/<noise>/summary_results.json
                                                          list of dicts with "raw_attention"

    Args:
        data_dir: Dataset root directory.
        threshold: Edge threshold passed to :func:`graph_f1`.

    Returns:
        Dict mapping number of agents to the list of F1 scores collected for it.
        Scores from every noise level are pooled under the same key.
    """
    scores = {}
    for folder in tqdm(os.listdir(data_dir)):
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            continue  # skip stray files (e.g. .DS_Store) instead of crashing
        # Assumes folder names look like "<prefix>_<num_agents>[_...]" -- TODO confirm.
        num_agents = int(folder.split("_")[1])

        # Every entry except the results folder is a run file. The listing does
        # not depend on the noise level, so compute it once per folder.
        graph_files = [name for name in os.listdir(folder_path) if name != RESULTS_DIR]

        noise_root = os.path.join(folder_path, NOISE_SUBDIR)
        for noise_level in os.listdir(noise_root):
            summary_path = os.path.join(noise_root, noise_level, "summary_results.json")
            with open(summary_path, "r", encoding="utf-8") as f:
                summary_results = json.load(f)

            # Index over run files only. The old i-1 over the full listing only
            # lined up when "results" happened to be listed first.
            # NOTE(review): assumes summary_results holds one entry per run file,
            # in os.listdir order -- confirm against the training script.
            for idx, graph in enumerate(graph_files):
                summary = summary_results[idx]
                with open(os.path.join(folder_path, graph), "r", encoding="utf-8") as f:
                    run_data = json.load(f)
                score = graph_f1(run_data["adjacency_matrix"], summary["raw_attention"], threshold)
                scores.setdefault(num_agents, []).append(score)
    return scores


def average_by_agent_count(scores):
    """Average the scores per agent count and order the result for plotting.

    Args:
        scores: Dict mapping agent count to a list of scores.

    Returns:
        ``(agent_counts, mean_scores)``: two parallel lists sorted by agent
        count. Keys with no scores are omitted.
    """
    agent_counts = sorted(count for count, values in scores.items() if values)
    mean_scores = [sum(scores[count]) / len(scores[count]) for count in agent_counts]
    return agent_counts, mean_scores


def main(argv=None):
    """Collect the F1 scores for the selected dataset and show the line plot."""
    argv = sys.argv if argv is None else argv
    data_dir = resolve_data_dir(argv)
    x, y = average_by_agent_count(collect_f1_scores(data_dir))
    # x is sorted, so the line connects points in order of agent count.
    plt.plot(x, y)
    plt.xlabel("Number of agents")
    plt.ylabel("Mean F1 score")
    plt.show()


if __name__ == "__main__":
    main()
2025-08-04 12:44:35 -04:00