training and plotting updates
This commit is contained in:
33
config_.py
Normal file
33
config_.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
class NoiseType(Enum):
|
||||
NONE = "none"
|
||||
NORMAL = "normal"
|
||||
UNIFORM = "uniform"
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
num_agents: int
|
||||
embedding_dim: int = 16
|
||||
input_dim: int = 1
|
||||
output_dim: int = 1
|
||||
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
"""Configuration for the training process."""
|
||||
learning_rate: float = 1e-3
|
||||
epochs: int = 100
|
||||
batch_size: int = 32
|
||||
verbose: bool = False # Set to True to see epoch loss during training
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
noise_type: NoiseType = NoiseType.NONE
|
||||
noise_level: float = 0.01
|
||||
|
||||
# --- New parameters for this script ---
|
||||
data_directory:str
|
||||
# Threshold for converting attention scores to a binary graph
|
||||
f1_threshold: float = -0.5
|
||||
|
Reference in New Issue
Block a user