import json
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import trt_pose.coco
import math
import os
import numpy as np
import traitlets
import pickle
import trt_pose.models
import torch
from torch2trt import TRTModule
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects
import torchvision.transforms as transforms
import PIL.Image
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from preprocessdata import preprocessdata
from dataloader import dataloader
from jetcam.csi_camera import CSICamera
from jetcam.utils import bgr8_to_jpeg
from sklearn.model_selection import GridSearchCV

# Hand description (keypoints + skeleton); its counts are passed to the model
# constructor below and it is also read by draw_joints().
with open('preprocess/hand_pose.json', 'r') as f:
    hand_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(hand_pose)  # a tensor
num_parts = len(hand_pose['keypoints'])
num_links = len(hand_pose['skeleton'])

model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()

# Input resolution. Per the original author's note these values are
# effectively required: the program errors out at higher values.
WIDTH = 224
HEIGHT = 224
# All-zero example input handed to torch2trt when the model is converted.
data = torch.zeros((1, 3 , HEIGHT, WIDTH)).cuda()

# Location of the TensorRT-optimized weights; built on first run if absent.
OPTIMIZED_MODEL = '/home/articulight/Desktop/project/model/hand_pose_resnet18_att_244_244_trt.pth'

if not os.path.exists(OPTIMIZED_MODEL):
    # No optimized model on disk yet: load the plain PyTorch weights, convert
    # them with torch2trt (fp16 enabled) and save the result for later runs.
    MODEL_WEIGHTS = '/home/articulight/Desktop/project/model/hand_pose_resnet18_att_244_244.pth'
    model.load_state_dict(torch.load(MODEL_WEIGHTS))
    import torch2trt
    model_trt = torch2trt.torch2trt(
        model,
        [data],
        fp16_mode=True,
        max_workspace_size=1 << 25,
    )
    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)

# Whether or not a conversion just happened, the model used for inference is
# always the one reloaded from the saved file.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))

parse_objects = ParseObjects(topology, cmap_threshold=0.12, link_threshold=0.15)
draw_objects = DrawObjects(topology)

# Per-channel statistics that preprocess() subtracts / divides by, kept on the GPU.
device = torch.device('cuda')
mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).cuda()

# Gesture class names; might need to be made in the gesture collection file.
with open('preprocess/gesture.json', 'r') as f:
    gesture = json.load(f)
gesture_type = gesture['classes']
print('gesture type: ' + str(gesture_type))

def draw_joints(image, joints):
    """Draw hand joints and the skeleton links between them onto ``image``.

    The image is modified in place through ``cv2.circle`` / ``cv2.line``.
    Nothing is drawn when three or more joints are ``[0, 0]``, which this
    function treats as "joint not detected".

    Args:
        image: Image accepted by the OpenCV drawing functions.
        joints: Sequence of ``[x, y]`` pairs. The 1-based joint numbers in
            ``hand_pose['skeleton']`` index into this sequence.

    Returns:
        None.
    """
    missing = sum(1 for joint in joints if joint == [0, 0])
    if missing >= 3:
        return

    # Every joint gets a small marker; the first joint is then re-drawn in a
    # second colour so it stands out.
    for joint in joints:
        cv2.circle(image, (joint[0], joint[1]), 2, (0, 0, 255), 1)
    first = joints[0]
    cv2.circle(image, (first[0], first[1]), 2, (255, 0, 255), 1)

    for link in hand_pose['skeleton']:
        start = joints[link[0] - 1]  # skeleton entries are 1-based
        end = joints[link[1] - 1]
        if start[0] == 0 or end[0] == 0:
            # Skip only this link. The previous `break` dropped every
            # remaining link as soon as one endpoint was missing, even when
            # both joints of the later links had been detected.
            continue
        cv2.line(image, (start[0], start[1]), (end[0], end[1]), (0, 255, 0), 1)

#converts image into tensor to supply into model
def preprocess(image):
    """Turn a BGR image array into a normalized CUDA tensor for the model.

    The image is converted BGR -> RGB, wrapped as a PIL image, converted to a
    tensor on the CUDA device, normalized in place with the module-level
    ``mean`` / ``std`` and returned with an extra leading axis.

    Side effect: rebinds the module-level ``device`` to the CUDA device.
    """
    global device
    device = torch.device('cuda')
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(rgb)).to(device)
    # In-place per-channel normalization: (x - mean) / std.
    tensor.sub_(mean[:, None, None])
    tensor.div_(std[:, None, None])
    return tensor[None, ...]

# Disabled alternative classifiers, kept for reference:
#clf = make_pipeline(StandardScaler(), SVC(C = 0.1, gamma=0.1, kernel='poly'))
#linear_clf = make_pipeline(StandardScaler(), LinearSVC(dual=False, ))

# Disabled grid search for SVC hyperparameter tuning:
#param_grid = {'C': [0.1, 1, 10], 'gamma': [0.1, 0.01, 0.001], 'kernel': ['rbf', 'sigmoid', 'poly']}
#grid = GridSearchCV(SVC(),  param_grid, refit=True, verbose=3)

# NOTE: this rebinds the imported `preprocessdata` class name to a single
# instance of it. From here on `preprocessdata` is that instance and the
# class itself is no longer reachable under this name.
preprocessdata = preprocessdata(topology, num_parts)

#--------------------------------------------------------------------------
# Dataset selection: change the path and file names here to match the
# dataset in use. These three values are passed to dataloader() when the
# SVM is trained.
path = "/home/articulight/Desktop/gesture_datasets/large/"
label_file = "large.json"
test_label = "large_testing.json"

def data_preprocess(images):
    """Run pose inference on every image and collect its joint-distance data.

    Each image is reversed along its second axis (presumably a horizontal
    mirror for height x width x channel arrays), pushed through the
    TensorRT model, parsed into joints and reduced with
    ``preprocessdata.find_distance``.

    Args:
        images: Iterable of image arrays indexable as ``im[:, ::-1, :]``.

    Returns:
        A list with one ``find_distance`` result per input image, in input
        order.
    """
    def _distances_for(frame):
        # Same steps as the live loop, applied to a mirrored copy of the frame.
        mirrored = frame[:, ::-1, :]
        cmap, paf = model_trt(preprocess(mirrored))
        cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
        counts, objects, peaks = parse_objects(cmap, paf)
        hand_joints = preprocessdata.joints_inference(mirrored, counts, objects, peaks)
        return preprocessdata.find_distance(hand_joints)

    return [_distances_for(frame) for frame in images]

# Disabled experiments, kept for reference (smaller dataset / grid search):
#train_images, labels_train = hand.smaller_dataset(hand.train_images,100,6)
#grid.fit(joints_train, hand.labels_train)
#print(grid.best_params_)

# True: train a new classifier and pickle it to `filename`.
# False: load the classifier pickled by an earlier run.
svm_train = True
filename = 'large.sav'
if svm_train:
    # Standardize the features, then classify with an RBF-kernel SVC.
    clf = make_pipeline(StandardScaler(), SVC(gamma='auto', kernel='rbf'))
    hand = dataloader(path, label_file, test_label)
    joints_train = data_preprocess(hand.train_images)
    joints_test = data_preprocess(hand.test_images)
    clf, predicted = preprocessdata.trainsvm(clf, joints_train, joints_test, hand.labels_train, hand.labels_test)
    # Context manager so the file is flushed and closed right away; the
    # previous bare open() left the handle to the garbage collector.
    with open(filename, 'wb') as model_file:
        pickle.dump(clf, model_file)
else:
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a model file that this script (or a trusted source) produced.
    with open(filename, 'rb') as model_file:
        clf = pickle.load(model_file)

# Disabled evaluation / prediction snippets, kept for reference:
#acc = preprocessdata.svm_accuracy(clf.predict(joints_test), hand.labels_test)
#pred = clf.predict([joints_test[40],[0]*num_parts*num_parts])
#pred2 = clf.predict(joints_test)
camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=7)

# Live classification loop: grab a frame, estimate the hand pose, classify
# the gesture and show the annotated frame. Runs until Ctrl-C.
# The try wraps the whole loop so an interrupt arriving between iterations is
# handled too, and the finally clause closes the preview window on any exit
# path (Ctrl-C or an unexpected error, which still propagates).
try:
    while True:
        image = camera.read()
        image = cv2.rotate(image, cv2.ROTATE_180)

        data = preprocess(image)
        cmap, paf = model_trt(data)
        cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
        counts, objects, peaks = parse_objects(cmap, paf)
        draw_objects(image, counts, objects, peaks)
        joints = preprocessdata.joints_inference(image, counts, objects, peaks)
        #draw_joints(image, joints)
        dist_bn_joints = preprocessdata.find_distance(joints)
        # NOTE(review): an all-zero second sample is classified alongside the
        # real one and its prediction is discarded; the reason is not visible
        # here, so it is kept as-is.
        gesture = clf.predict([dist_bn_joints, [0]*num_parts*num_parts])
        print(gesture)
        gesture_joints = gesture[0]
        #print("supposed index joint:" + str(joints[6][:]))
        # Append the newest prediction and drop the oldest, so the queue
        # length stays constant.
        preprocessdata.prev_queue.append(gesture_joints)
        preprocessdata.prev_queue.pop(0)
        preprocessdata.print_label(image, preprocessdata.prev_queue, gesture_type)
        #image = bgr8_to_jpeg(image[:, ::-1, :])
        cv2.imshow('camera feed', image)
        # waitKey lets the HighGUI window process events and repaint.
        cv2.waitKey(1)
except KeyboardInterrupt:
    # Ctrl-C is the intended way to stop the demo; exit quietly.
    pass
finally:
    cv2.destroyAllWindows()