Commit 1776571d authored by sjmonagi
Browse files

images only

parent 81d49779
...@@ -10,10 +10,13 @@ from autoencoder import load_autoencoder ...@@ -10,10 +10,13 @@ from autoencoder import load_autoencoder
from experience_buffer import experience_buffer from experience_buffer import experience_buffer
from helper import plotting_training_log, train_valid_env_sync, validate from helper import plotting_training_log, train_valid_env_sync, validate
random.seed(123) random.seed(123)
np.random.seed(123) np.random.seed(123)
dir = "/home/nagi/Desktop/Master_project_final/DRQN_3_her_sparse_ More_sequences_More_/DRQN.ckpt" fields_name = ["iteration", "successes"]
dir = "/home/nagi/Desktop/Master_project_final/DRQN_3_her_shaped_image_and_pos/DRQN.ckpt"
##### environment_Variables ##### environment_Variables
grid_size = 0.18 # size of the agent step grid_size = 0.18 # size of the agent step
...@@ -22,11 +25,11 @@ distance_threshold = grid_size * 2 # distance threshold to the goal ...@@ -22,11 +25,11 @@ distance_threshold = grid_size * 2 # distance threshold to the goal
action_n = 3 # number of allowed action action_n = 3 # number of allowed action
random_init_position = False # Random initial positions only -- no change in the agent orientation random_init_position = False # Random initial positions only -- no change in the agent orientation
random_init_pose = True # Random initial positions with random agent orientation random_init_pose = True # Random initial positions with random agent orientation
reward = "sparse" # reward type "shaped","sparse" reward = "shaped" # reward type "shaped","sparse"
######################### hyper-parameter ######################### hyper-parameter
num_episodes = 15001 num_episodes = 15001
her_samples = 64 her_samples = 8
batch_size = 32 batch_size = 32
trace_length = 8 trace_length = 8
gamma = 0.99 gamma = 0.99
...@@ -34,9 +37,9 @@ fcl_dims = 512 ...@@ -34,9 +37,9 @@ fcl_dims = 512
nodes_num = 256 nodes_num = 256
optimistion_steps = 40 optimistion_steps = 40
epsilon_max = 1 epsilon_max = 1
epsilon_min = 0.05 epsilon_min = 0
input_size = 521 ## size of the input to the LSTM input_size = 521 ## size of the input to the LSTM
epsilon_decay = epsilon_max - ((epsilon_max - epsilon_min) / 3500) epsilon_decay = epsilon_max - (epsilon_max / 3500)
## pandas data-frame for plotting ## pandas data-frame for plotting
...@@ -86,6 +89,7 @@ with tf.Session() as sess: ...@@ -86,6 +89,7 @@ with tf.Session() as sess:
epsilon = 1 epsilon = 1
for n in range(start, num_episodes): for n in range(start, num_episodes):
step_num = 0
# rnn_init_state # rnn_init_state
rnn_state = (np.zeros([1, nodes_num]), np.zeros([1, nodes_num])) rnn_state = (np.zeros([1, nodes_num]), np.zeros([1, nodes_num]))
# reset environment # reset environment
...@@ -139,13 +143,17 @@ with tf.Session() as sess: ...@@ -139,13 +143,17 @@ with tf.Session() as sess:
obs_pos_state = obs_pos_state_ obs_pos_state = obs_pos_state_
distance = distance_ distance = distance_
pre_action_idx = curr_action_idx pre_action_idx = curr_action_idx
step_num += 1
if done: if done:
if distance < distance_threshold: if distance < distance_threshold:
successes += done successes += done
else: else:
failures += done failures += done
break break
if step_num == 200:
done = True
failures += done
break
her_buffer = episode_buffer.her() her_buffer = episode_buffer.her()
her_rec_buffer.add(her_buffer) her_rec_buffer.add(her_buffer)
...@@ -157,10 +165,11 @@ with tf.Session() as sess: ...@@ -157,10 +165,11 @@ with tf.Session() as sess:
"Ratio": (successes / (failures + 1e-6)), "Ratio": (successes / (failures + 1e-6)),
"loss": loss, "epsilon": epsilon}, ignore_index=True) "loss": loss, "epsilon": epsilon}, ignore_index=True)
plotting_training_log(n, plotted_data, successes, failures, loss, goal, distance, pos_state, epsilon)
plotting_training_log(n, plotted_data, successes, failures, loss, goal, distance, pos_state, epsilon, step_num)
###validation### ###validation###
if n % 5000 == 0 and n > 0: if n % 2000 == 0 and n > 0:
validate(n=n, nodes_num=nodes_num, top_view=top_view, env=env, envT=envT, ae=ae, ae_sess=ae_sess, validate(n=n, nodes_num=nodes_num, top_view=top_view, env=env, envT=envT, ae=ae, ae_sess=ae_sess,
distance_threshold=distance_threshold, model=model) distance_threshold=distance_threshold, model=model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment