Updated data generation code
This commit is contained in:
19
config_.py
19
config_.py
@@ -9,7 +9,7 @@ class NoiseType(Enum):
|
||||
|
||||
class ModelConfig:
|
||||
num_agents: int
|
||||
embedding_dim: int = 16
|
||||
embedding_dim: int = 64
|
||||
input_dim: int = 1
|
||||
output_dim: int = 1
|
||||
|
||||
@@ -19,15 +19,12 @@ class TrainConfig:
|
||||
"""Configuration for the training process."""
|
||||
learning_rate: float = 1e-3
|
||||
epochs: int = 100
|
||||
batch_size: int = 32
|
||||
verbose: bool = False # Set to True to see epoch loss during training
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
noise_type: NoiseType = NoiseType.NONE
|
||||
noise_level: float = 0.01
|
||||
|
||||
# --- New parameters for this script ---
|
||||
batch_size: int = 4096 * 4
|
||||
verbose: bool = False # See loss during training
|
||||
log: bool = True # Logs the attention every log interval
|
||||
log_epoch_interval: int = 10 # How often you want to log intervals
|
||||
noise_type: NoiseType = NoiseType.NONE # What kind of noise you want
|
||||
noise_level: float = 0.00 # Stddev for Normal, half-width for Uniform
|
||||
data_directory:str
|
||||
# Threshold for converting attention scores to a binary graph
|
||||
f1_threshold: float = -0.5
|
||||
f1_threshold: float = -0.5 # Threshold for binary classification
|
||||
|
Reference in New Issue
Block a user