Commit f02cb33b authored by Matteo Argenton

adamw optimizer

parent bf705efe
Tag: v0.1.0
@@ -83,7 +83,7 @@ def test_step(model, state, graph, label):
 def create_train_state(model, key, graph, lr):
     params = model.init(key, graph)['params']
-    optimizer = optax.adam(learning_rate=lr)
+    optimizer = optax.adamw(learning_rate=lr)
     return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


 def model_setup(config, session_id=0):
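
The hunk above swaps plain Adam for AdamW, which uses decoupled weight decay: the decay shrinks the parameters directly at each update instead of entering the gradients as an L2 penalty. Since no weight_decay argument is passed, optax's default of 1e-4 applies. A minimal sketch of how the decay strength could be made configurable — the explicit weight_decay parameter is an assumption for illustration, not part of the commit:

    import optax
    from flax.training import train_state

    def create_train_state(model, key, graph, lr, weight_decay=1e-4):
        # Initialize model parameters from a PRNG key and an example graph.
        params = model.init(key, graph)['params']
        # optax.adamw applies weight decay to the parameters themselves
        # (decoupled), rather than folding an L2 term into the gradients
        # as Adam-plus-L2 would; weight_decay=1e-4 matches optax's default.
        optimizer = optax.adamw(learning_rate=lr, weight_decay=weight_decay)
        return train_state.TrainState.create(
            apply_fn=model.apply, params=params, tx=optimizer)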
@@ -108,7 +108,7 @@ def metrics_setup(config, model, key, graph, orbax_ckpnt):
     elif config['run_type'] == 'load_run':
         try:
             raw_restored = orbax_ckpnt.restore(config['ckpt_dir'])
-            state = train_state.TrainState.create(apply_fn=model.apply, params=raw_restored['model']['params'], tx=optax.adam(learning_rate=config['learning_r']))
+            state = train_state.TrainState.create(apply_fn=model.apply, params=raw_restored['model']['params'], tx=optax.adamw(learning_rate=config['learning_r']))
             return (state,
                     raw_restored['tr_l'],
                     raw_restored['valid_l'],
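
In this load_run branch, TrainState.create builds a brand-new optimizer state from the restored parameters, so AdamW's first- and second-moment estimates are not restored and restart from zero on resume. A minimal sketch of that resume path, assuming an orbax PyTreeCheckpointer and the config keys and model object used above ('ckpt_dir', 'learning_r'):

    import optax
    import orbax.checkpoint
    from flax.training import train_state

    orbax_ckpnt = orbax.checkpoint.PyTreeCheckpointer()
    raw_restored = orbax_ckpnt.restore(config['ckpt_dir'])
    # Rebuild the train state from the checkpointed parameters; the
    # optimizer state (AdamW moment estimates) is created fresh here,
    # not loaded from the checkpoint.
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=raw_restored['model']['params'],
        tx=optax.adamw(learning_rate=config['learning_r']),
    )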