33 lines
765 B
Python
33 lines
765 B
Python
|
import sys
|
||
|
from enum import Enum
|
||
|
|
||
|
class NoiseType(Enum):
    """Closed set of noise options, selected through ``TrainConfig.noise_type``.

    Every member's value is its name in lowercase, so ``NoiseType("normal")``
    looks up ``NoiseType.NORMAL``.
    """

    # Presumably disables noise entirely; confirm against the consumer.
    NONE = "none"
    # Presumably Gaussian-distributed noise; confirm against the consumer.
    NORMAL = "normal"
    # Presumably uniformly distributed noise; confirm against the consumer.
    UNIFORM = "uniform"
|
||
|
|
||
|
|
||
|
class ModelConfig:
    """Model hyper-parameters, declared as typed class attributes.

    Fields that carry a default exist as class attributes. ``num_agents`` is
    annotation-only, so it has no value until it is assigned on an instance.
    """

    # NOTE(review): annotation-only. Reading ``ModelConfig().num_agents`` before
    # assigning it raises AttributeError, and the class has no ``__init__``, so
    # ``ModelConfig(num_agents=4)`` raises TypeError. If keyword construction is
    # intended, this class probably wants ``@dataclasses.dataclass``; confirm
    # against the callers.
    num_agents: int

    # Size of the embedding vectors (assumed meaning; confirm against the model).
    embedding_dim: int = 16

    # Per-item input and output feature sizes (assumed meaning; confirm).
    input_dim: int = 1
    output_dim: int = 1
|
||
|
|
||
|
|
||
|
|
||
|
class TrainConfig:
    """Settings that control a training run.

    Fields that have a default are plain class attributes. ``data_directory``
    is annotation-only and must be assigned on an instance before it is read.
    """

    # --- Optimisation ---
    learning_rate: float = 1e-3
    epochs: int = 100
    batch_size: int = 32

    # --- Reporting ---
    verbose: bool = False  # Set to True to see epoch loss during training
    log: bool = True
    # Presumably the logging cadence in epochs when ``log`` is enabled; confirm.
    log_epoch_interval: int = 10

    # --- Noise ---
    noise_type: NoiseType = NoiseType.NONE
    # Presumably the magnitude of the selected noise; confirm against the consumer.
    noise_level: float = 0.01

    # --- New parameters for this script ---
    # NOTE(review): annotation-only, so ``TrainConfig().data_directory`` raises
    # AttributeError until it is assigned. Converting this class to a dataclass
    # would also require moving this field ahead of the defaulted ones (or
    # giving it a default), because a field without a default may not follow
    # fields that have one.
    data_directory: str

    # Threshold for converting attention scores to a binary graph.
    # NOTE(review): a negative threshold is unusual if the scores are
    # non-negative (e.g. softmax outputs); confirm the intended scale.
    f1_threshold: float = -0.5
|
||
|
|