"""Configuration definitions: noise types, model settings, and training settings."""

import sys
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class NoiseType(Enum):
    """Closed set of noise kinds; used as the type of ``TrainConfig.noise_type``.

    Each member's value is the lowercase form of its name, so a member can be
    looked up from its string value, e.g. ``NoiseType("normal")``.
    """

    NONE = "none"
    NORMAL = "normal"
    UNIFORM = "uniform"


@dataclass
class ModelConfig:
    """Model hyperparameters.

    Fields can be given by keyword at construction, e.g.
    ``ModelConfig(num_agents=4, embedding_dim=32)``, or assigned afterwards.

    Attributes:
        num_agents: Number of agents. Has no meaningful default; it is
            ``None`` until supplied and must be set before use.
        embedding_dim: Embedding dimension (default 16).
        input_dim: Input dimension (default 1).
        output_dim: Output dimension (default 1).
    """

    # Defaulting to None (rather than leaving a bare annotation, which never
    # creates an attribute) keeps ``ModelConfig()`` valid while letting
    # callers pass the value at construction time.
    num_agents: Optional[int] = None
    embedding_dim: int = 16
    input_dim: int = 1
    output_dim: int = 1


@dataclass
class TrainConfig:
    """Configuration for the training process.

    Every field has a default, so ``TrainConfig()`` is valid. Any field can be
    overridden by keyword, e.g. ``TrainConfig(epochs=5, data_directory="data")``,
    or assigned on an existing instance.
    """

    learning_rate: float = 1e-3
    epochs: int = 100
    batch_size: int = 32
    verbose: bool = False  # Set to True to see epoch loss during training
    log: bool = True
    log_epoch_interval: int = 10
    noise_type: NoiseType = NoiseType.NONE
    noise_level: float = 0.01

    # --- New parameters for this script ---
    # A bare annotation never creates an attribute, and a default-less field
    # after defaulted ones is rejected by @dataclass; None marks "not set yet",
    # so supply the directory before use.
    data_directory: Optional[str] = None
    # Threshold for converting attention scores to a binary graph.
    # NOTE(review): the default is negative — confirm it matches the scale of
    # the scores it is compared against.
    f1_threshold: float = -0.5
