Updated data generation code

This commit is contained in:
2025-09-01 14:46:34 -04:00
parent d998f6de4c
commit e018238935
14 changed files with 709 additions and 123 deletions

View File

@@ -9,7 +9,7 @@ class NoiseType(Enum):
class ModelConfig:
num_agents: int
embedding_dim: int = 16
embedding_dim: int = 64
input_dim: int = 1
output_dim: int = 1
@@ -19,15 +19,12 @@ class TrainConfig:
"""Configuration for the training process."""
learning_rate: float = 1e-3
epochs: int = 100
batch_size: int = 32
verbose: bool = False # Set to True to see epoch loss during training
log: bool = True
log_epoch_interval: int = 10
noise_type: NoiseType = NoiseType.NONE
noise_level: float = 0.01
# --- New parameters for this script ---
batch_size: int = 4096 * 4
verbose: bool = False # See loss during training
log: bool = True # Logs the attention every log interval
log_epoch_interval: int = 10 # How often you want to log intervals
noise_type: NoiseType = NoiseType.NONE # What kind of noise you want
noise_level: float = 0.00 # Stddev for Normal, half-width for Uniform
data_directory:str
# Threshold for converting attention scores to a binary graph
f1_threshold: float = -0.5
f1_threshold: float = -0.5 # Threshold for binary classification