training and plotting updates

This commit is contained in:
2025-08-04 12:44:35 -04:00
parent 5a0c479c1e
commit d998f6de4c
9 changed files with 457 additions and 71 deletions

33
config_.py Normal file
View File

@@ -0,0 +1,33 @@
import sys
from enum import Enum
class NoiseType(Enum):
NONE = "none"
NORMAL = "normal"
UNIFORM = "uniform"
class ModelConfig:
num_agents: int
embedding_dim: int = 16
input_dim: int = 1
output_dim: int = 1
class TrainConfig:
"""Configuration for the training process."""
learning_rate: float = 1e-3
epochs: int = 100
batch_size: int = 32
verbose: bool = False # Set to True to see epoch loss during training
log: bool = True
log_epoch_interval: int = 10
noise_type: NoiseType = NoiseType.NONE
noise_level: float = 0.01
# --- New parameters for this script ---
data_directory:str
# Threshold for converting attention scores to a binary graph
f1_threshold: float = -0.5