2025-08-04 12:44:35 -04:00
|
|
|
import sys
|
|
|
|
from enum import Enum
|
|
|
|
|
|
|
|
class NoiseType(Enum):
    """Closed set of noise distributions selectable for training.

    Chosen via ``TrainConfig.noise_type``; the magnitude is taken from
    ``TrainConfig.noise_level`` (standard deviation for NORMAL, half-width
    for UNIFORM). Each member's value is its lowercase name, so
    ``NoiseType("normal")`` round-trips from a plain string.
    """

    # No noise is applied.
    NONE = "none"
    # Gaussian noise; noise_level is the standard deviation.
    NORMAL = "normal"
    # Uniform noise; noise_level is the half-width of the interval.
    UNIFORM = "uniform"
|
|
|
|
|
|
|
|
|
|
|
|
class ModelConfig:
    """Configuration for the model's dimensions.

    As written this is a plain annotated class (no ``@dataclass``), so the
    constructor takes no arguments. Fields with defaults exist as class
    attributes. ``num_agents`` is only annotated: it has no value until a
    caller assigns it on the instance or the class.
    """

    # Required: no default, must be assigned before use.
    # Presumably the number of agents the model handles — confirm with callers.
    num_agents: int

    # Size of the learned embedding (presumably per agent — TODO confirm).
    embedding_dim: int = 64

    # Width of each input / output feature vector.
    input_dim: int = 1
    output_dim: int = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainConfig:
    """Configuration for the training process.

    A plain annotated class (no ``@dataclass``): every field with a default
    is a class attribute. ``data_directory`` is annotation-only and must be
    assigned before it is read.
    """

    # --- optimisation ---------------------------------------------------
    learning_rate: float = 0.001
    epochs: int = 100
    batch_size: int = 16_384  # i.e. 4 * 4096

    # --- reporting ------------------------------------------------------
    verbose: bool = False  # Show the loss during training
    log: bool = True  # Log the attention periodically (see log_epoch_interval)
    log_epoch_interval: int = 10  # Epochs between successive logs

    # --- noise injection ------------------------------------------------
    noise_type: NoiseType = NoiseType.NONE  # Which noise distribution to use
    noise_level: float = 0.0  # Stddev for Normal, half-width for Uniform

    # --- data -----------------------------------------------------------
    # Required: no default. NOTE(review): this field sits between defaulted
    # fields, so converting the class to a @dataclass would need it moved
    # ahead of them.
    data_directory: str

    # --- evaluation -----------------------------------------------------
    # Threshold for binary classification. NOTE(review): a negative default
    # is unusual for probabilities; presumably applied to raw scores or
    # {-1, 1} targets — confirm.
    f1_threshold: float = -0.5
|
2025-08-04 12:44:35 -04:00
|
|
|
|