#include <windows.h>
#include <process.h>
#include <shlobj.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <jni.h>

#include "jvm.h"
#include "dde_connect.h"

// Forward declarations for functions defined later in this file.

// Callback passed to dde_exec() in main(); runs the Java main method on an
// argument list delivered as one newline-separated buffer.
void CallMainMethod(void* arglist);
// Splits a buffer in place on '\n' into a calloc'd argv-style array.
char** arg_split(char* buffer, int* p_argc);

// Access to RCDATA resources embedded in this executable (see the cache below).
void    InitResource(LPCTSTR name);
DWORD   GetResourceSize(LPCTSTR name);
jbyte*  GetResourceBuffer(LPCTSTR name);
// Conversion between ANSI code page (CP_ACP) strings and Java strings.
jstring GetJString(const char* value);
LPSTR   GetShiftJIS(jstring src);
// Error/message reporting through a message box.
void    OutputError(LPCTSTR text);
void    OutputMessage(LPCTSTR text);
// Returns a malloc'd "EXEWRAP:MUTEX:<exe file name>" string.
LPSTR	GetModuleMutexName();

// Native body bound to UncaughtHandler.UncaughtException(String, String) in main().
void JNICALL JNI_UncaughtException(JNIEnv *env, jobject clazz, jstring message, jstring trace);

// Single-entry resource cache maintained by InitResource():
//   cache  - name of the currently cached resource; the 32-character initial
//            placeholder also sizes the array (32 chars + NUL) and presumably
//            matches no real resource name
//   buffer - the cached resource's data pointer (from LockResource)
//   size   - the cached resource's size in bytes (from SizeofResource)
TCHAR  cache[] = "12345678901234567890123456789012";
jbyte* buffer = NULL;
DWORD  size   = 0;

// Main class and its static main(String[]) method, resolved in main() and
// reused by CallMainMethod().
jclass mainClass = NULL;
jmethodID mainMethod = NULL;

/*
 * Returns a heap-allocated, NUL-terminated copy of RCDATA resource `name`, or
 * NULL when the resource is missing, empty, or memory runs out. A copy is used
 * because resource data carries no NUL terminator and may be mapped read-only,
 * so it must neither be scanned with str* functions nor modified in place
 * (e.g. by _strupr). The caller owns the result (free()).
 */
static LPSTR CopyResourceString(LPCTSTR name) {
	HRSRC hrsrc = FindResource(NULL, name, RT_RCDATA);
	if(hrsrc == NULL) {
		return NULL;
	}
	DWORD length = SizeofResource(NULL, hrsrc);
	HGLOBAL hglobal = LoadResource(NULL, hrsrc);
	const char* data = (hglobal != NULL) ? (const char*)LockResource(hglobal) : NULL;
	if(length == 0 || data == NULL) {
		return NULL;
	}
	LPSTR copy = (LPSTR)malloc(length + 1);
	if(copy == NULL) {
		return NULL;
	}
	memcpy(copy, data, length);
	copy[length] = '\0';
	return copy;
}

/*
 * Clears a Java exception left pending by a failed optional setup step, so
 * that later JNI calls are not made while an exception is pending.
 */
static void ClearOptionalStepException() {
	if(env->ExceptionCheck() == JNI_TRUE) {
		env->ExceptionClear();
	}
}

/*
 * Optional step: defines the embedded FileLogStream class (resource FILELOG)
 * and calls FileLogStream.initialize(String) with this executable's path.
 * Any failure is ignored.
 */
static void InitFileLogStream() {
	DWORD classSize = GetResourceSize("FILELOG");
	jbyte* classData = GetResourceBuffer("FILELOG");
	if(classSize == 0 || classData == NULL) {
		return;
	}
	jclass fileLogStreamClass = env->DefineClass("FileLogStream", NULL, classData, (jsize)classSize);
	if(fileLogStreamClass == NULL) {
		ClearOptionalStepException();
		return;
	}
	jmethodID fileLogStreamInit = env->GetStaticMethodID(fileLogStreamClass, "initialize", "(Ljava/lang/String;)V");
	if(fileLogStreamInit == NULL) {
		ClearOptionalStepException();
		return;
	}
	char moduleFileName[MAX_PATH];
	if(GetModuleFileName(NULL, moduleFileName, MAX_PATH) == 0) {
		moduleFileName[0] = '\0';
	}
	moduleFileName[MAX_PATH - 1] = '\0'; // a truncated path may be unterminated on older Windows
	env->CallStaticVoidMethod(fileLogStreamClass, fileLogStreamInit, GetJString(moduleFileName));
	ClearOptionalStepException();
}

/*
 * Optional step, only when Loader.getRuntimeVersion() reports >= 150 (Java 1.5+):
 * defines the embedded UncaughtHandler class (resource HANDLER), calls its
 * static initialize(), and binds its native UncaughtException(String, String)
 * to JNI_UncaughtException.
 *
 * Returns 0 on success or when the step is skipped, -6 if RegisterNatives fails.
 */
static int InstallUncaughtHandler(jclass loaderClass) {
	jmethodID loaderGetRuntimeVersion = env->GetStaticMethodID(loaderClass, "getRuntimeVersion", "()I");
	if(loaderGetRuntimeVersion == NULL) {
		ClearOptionalStepException();
		return 0;
	}
	jint version = env->CallStaticIntMethod(loaderClass, loaderGetRuntimeVersion);
	if(env->ExceptionCheck() == JNI_TRUE) {
		env->ExceptionClear();
		return 0;
	}
	if(version < 150) {
		return 0;
	}

	DWORD classSize = GetResourceSize("HANDLER");
	jbyte* classData = GetResourceBuffer("HANDLER");
	if(classSize == 0 || classData == NULL) {
		return 0;
	}
	jclass uncaughtHandlerClass = env->DefineClass("UncaughtHandler", NULL, classData, (jsize)classSize);
	if(uncaughtHandlerClass == NULL) {
		ClearOptionalStepException();
		return 0;
	}
	jmethodID uncaughtHandlerInit = env->GetStaticMethodID(uncaughtHandlerClass, "initialize", "()V");
	if(uncaughtHandlerInit == NULL) {
		ClearOptionalStepException();
		return 0;
	}
	env->CallStaticVoidMethod(uncaughtHandlerClass, uncaughtHandlerInit);
	ClearOptionalStepException();

	JNINativeMethod nm;
	// jni.h declares these fields as non-const char*; cast the literals explicitly.
	nm.name = (char*)"UncaughtException";
	nm.signature = (char*)"(Ljava/lang/String;Ljava/lang/String;)V";
	nm.fnPtr = (void*)JNI_UncaughtException;
	if(env->RegisterNatives(uncaughtHandlerClass, &nm, 1) != 0) {
		fputs("Fail to regist native method: UncaughtException\n", stderr);
		return -6;
	}
	return 0;
}

/*
 * Launcher entry point. It creates the JVM and defines the classes embedded as
 * RCDATA resources (FileLogStream, Toolkit, Loader and, optionally,
 * UncaughtHandler). It then loads the main class from the embedded JAR through
 * Loader#defineClass(byte[]) and invokes its static main(String[]) with argv[1..].
 *
 * The optional EXTFLAGS resource is matched case-insensitively:
 *   DDE_CONNECT - call dde_exec() and exit if it reports DDE_CLIENT
 *                 (presumably another instance took the command line).
 *   SINGLE      - exit if the module mutex already exists.
 *
 * Returns 0 on success or a negative code identifying the failing step.
 */
int main(int argc, char* argv[]) {
	LPSTR ext_flags = CopyResourceString("EXTFLAGS");
	if(ext_flags != NULL) {
		_strupr(ext_flags);
	}
	int ddeConnect = (ext_flags != NULL && strstr(ext_flags, "DDE_CONNECT") != NULL);
	int singleInstance = (ext_flags != NULL && strstr(ext_flags, "SINGLE") != NULL);
	free(ext_flags);

	if(ddeConnect) {
		// The name is not freed: whether dde_exec keeps using it is not visible here.
		if(dde_exec(CallMainMethod, GetModuleMutexName(), argc, argv) == DDE_CLIENT) {
			return 0;
		}
	}
	if(singleInstance) {
		LPSTR mutexName = GetModuleMutexName();
		// The handle is deliberately not closed: the named mutex must stay open
		// for this process's lifetime so later instances see ERROR_ALREADY_EXISTS.
		CreateMutex(NULL, TRUE, mutexName);
		DWORD lastError = GetLastError();
		free(mutexName);
		if(lastError == ERROR_ALREADY_EXISTS) {
			return 0;
		}
	}

	// Kept alive for the whole process: whether CreateJavaVM retains the
	// pointer is not visible here.
	LPSTR vm_args_opt = CopyResourceString("VMARGS");
	if(CreateJavaVM(vm_args_opt) == NULL) {
		OutputMessage("JavaVM 쐬ł܂łB");
		return -1;
	}

	// FileLogStream (optional)
	InitFileLogStream();

	// Toolkit
	jclass toolkitClass = env->DefineClass("Toolkit", NULL, GetResourceBuffer("TOOLKIT"), GetResourceSize("TOOLKIT"));
	if(toolkitClass == NULL) {
		OutputMessage("Failed to define class: Toolkit\n");
		return -3;
	}
	// Loader
	jclass loaderClass = env->DefineClass("Loader", NULL, GetResourceBuffer("LOADER"), GetResourceSize("LOADER"));
	if(loaderClass == NULL) {
		OutputMessage("Fail to define class: Loader\n");
		return -4;
	}
	jmethodID loaderInit = env->GetMethodID(loaderClass, "<init>", "()V");
	if(loaderInit == NULL) {
		fputs("Fail to find method: Loader#init()\r\n", stderr);
		return -3;
	}
	jobject loader = env->NewObject(loaderClass, loaderInit);
	if(loader == NULL) {
		fputs("Fail to init: Loader\r\n", stderr);
		return -4;
	}
	jmethodID defineClass = env->GetMethodID(loaderClass, "defineClass", "([B)Ljava/lang/Class;");
	if(defineClass == NULL) {
		OutputMessage("Fail to find method: Loader#defineClass(byte[])\n");
		return -6;
	}

	// Main-Class: hand the embedded JAR bytes to Loader#defineClass(byte[]).
	jsize jarSize = (jsize)GetResourceSize("JAR");
	jbyte* jarData = GetResourceBuffer("JAR");
	jbyteArray jar = env->NewByteArray(jarSize);
	if(jar == NULL) {
		OutputError("Fail to allocate byte array: JAR\r\n");
		return -7;
	}
	env->SetByteArrayRegion(jar, 0, jarSize, jarData);

	mainClass = (jclass)(env->CallObjectMethod(loader, defineClass, jar));
	if(mainClass == NULL) {
		OutputError("CNX̃[hɎs܂B\r\n");
		return -8;
	}
	mainMethod = env->GetStaticMethodID(mainClass, "main", "([Ljava/lang/String;)V");
	if(mainMethod == NULL) {
		OutputMessage("main Ă܂B");
		return -9;
	}

	// UncaughtHandler (version 1.5.0 or higher only)
	int handlerResult = InstallUncaughtHandler(loaderClass);
	if(handlerResult != 0) {
		return handlerResult;
	}

	// Invoke main(String[]) with argv[1..]; argv[0] is the launcher itself.
	jobjectArray args = env->NewObjectArray(argc - 1, env->FindClass("java/lang/String"), NULL);
	if(args == NULL) {
		OutputError(NULL);
		return -1;
	}
	for(int i = 1; i < argc; i++) {
		env->SetObjectArrayElement(args, (i - 1), GetJString(argv[i]));
	}
	env->CallStaticVoidMethod(mainClass, mainMethod, args);

	if(env->ExceptionCheck() == JNI_TRUE) {
		OutputError(NULL);
		return -1;
	}
	DetachJavaVM();
	DestroyJavaVM();
	return 0;
}

/*
 * Thread entry handed to dde_exec() by main(). It runs the cached Java main
 * method (mainClass/mainMethod) with an argument list received as one
 * newline-separated string.
 *
 * arglist - heap buffer split in place by arg_split(). Its first token is
 *           skipped, mirroring main() dropping argv[0]. This function releases
 *           the buffer with free() and then ends the thread with _endthread().
 *
 * NOTE(review): GetJString() and OutputError() use the file-global `env`, not
 * the JNIEnv attached here. Confirm that AttachJavaVM() updates that global,
 * because a JNIEnv is only valid on its own thread.
 */
void CallMainMethod(void* arglist) {
	int argc = 0;
	char** argv = arg_split((char*)arglist, &argc);
	JNIEnv* env = AttachJavaVM();
	if(env != NULL) {
		// An empty list must not request a negative array length.
		jsize count = (argc > 1) ? (jsize)(argc - 1) : 0;
		jobjectArray args = env->NewObjectArray(count, env->FindClass("java/lang/String"), NULL);
		if(args != NULL) {
			for(int i = 1; i < argc; i++) {
				env->SetObjectArrayElement(args, (i - 1), GetJString(argv[i]));
			}
			env->CallStaticVoidMethod(mainClass, mainMethod, args);
		}
		// Reports a failed array allocation as well as an exception from main().
		if(env->ExceptionCheck() == JNI_TRUE) {
			OutputError(NULL);
		}
		DetachJavaVM();
	}
	free(argv);
	free(arglist);
	_endthread();
}

/*
 * Splits `buffer` in place on '\n' into an argv-style array.
 *
 * *p_argc receives the number of '\n' characters in `buffer`, and the returned
 * calloc'd array has that many slots. Slots are filled in order with the
 * tokens strtok() finds, so empty segments are skipped and any slots left over
 * stay NULL. The tokens point into `buffer`, which must outlive the array.
 * The caller frees only the array, with free().
 *
 * If the allocation fails, *p_argc is set to 0 and NULL is returned.
 */
char** arg_split(char* buffer, int* p_argc) {
	size_t count = 0;
	for(const char* p = strchr(buffer, '\n'); p != NULL; p = strchr(p + 1, '\n')) {
		count++;
	}

	char** argv = (char**)calloc(count, sizeof(char*));
	if(argv == NULL && count > 0) {
		*p_argc = 0;
		return NULL;
	}
	*p_argc = (int)count;

	for(size_t i = 0; i < count; i++) {
		char* token = strtok(i == 0 ? buffer : NULL, "\n");
		if(token == NULL) {
			break; // remaining slots are already NULL from calloc
		}
		argv[i] = token;
	}
	return argv;
}

/*
 * Converts a NUL-terminated string in the ANSI code page (CP_ACP) to a Java
 * String, using the file-global JNIEnv `env`.
 *
 * Returns NULL if src is NULL, or if the scratch buffer or NewString fails.
 * Returns an empty Java string if src is empty or cannot be converted.
 * Otherwise returns the converted string.
 */
jstring GetJString(const char* src) {
	static const jchar empty[1] = { 0 };
	if(src == NULL) {
		return NULL;
	}
	int srcLength = (int)strlen(src);
	// MultiByteToWideChar rejects zero-length input, so "" is not converted.
	int wSize = (srcLength > 0) ? MultiByteToWideChar(CP_ACP, 0, src, srcLength, NULL, 0) : 0;
	if(wSize <= 0) {
		return env->NewString(empty, 0);
	}
	// Use a heap buffer rather than a variable-length array. VLAs are not
	// standard C++, and a long argument could overflow the stack.
	WCHAR* wBuf = (WCHAR*)malloc(sizeof(WCHAR) * wSize);
	if(wBuf == NULL) {
		return NULL;
	}
	MultiByteToWideChar(CP_ACP, 0, src, srcLength, wBuf, wSize);
	jstring result = env->NewString((const jchar*)wBuf, wSize);
	free(wBuf); // NewString copies the characters
	return result;
}

/*
 * Converts a Java String to a NUL-terminated string in the ANSI code page
 * (CP_ACP; Shift_JIS on Japanese systems), using the file-global JNIEnv `env`.
 *
 * Returns a GlobalAlloc'd buffer that the caller releases with GlobalFree().
 * Returns NULL if src is NULL, if the characters cannot be obtained, or if the
 * allocation fails.
 */
LPSTR GetShiftJIS(jstring src) {
	if(src == NULL) {
		return NULL;
	}
	// GetStringChars does not NUL-terminate, so take the length from the JVM
	// instead of scanning for a terminator.
	jsize length = env->GetStringLength(src);
	const jchar* unicode = env->GetStringChars(src, NULL);
	if(unicode == NULL) {
		return NULL;
	}
	// Ask for the exact output size instead of assuming 2 bytes per character.
	int bytes = 0;
	if(length > 0) {
		bytes = WideCharToMultiByte(CP_ACP, 0, (LPCWSTR)unicode, length, NULL, 0, NULL, NULL);
	}
	// GPTR zero-fills the buffer, which provides the terminating NUL.
	LPSTR ret = (LPSTR)GlobalAlloc(GPTR, bytes + 1);
	if(ret != NULL && bytes > 0) {
		WideCharToMultiByte(CP_ACP, 0, (LPCWSTR)unicode, length, ret, bytes, NULL, NULL);
	}
	env->ReleaseStringChars(src, unicode);
	return ret;
}

/*
 * Loads RCDATA resource `name` from this executable into the single-entry
 * cache. `buffer` receives its data pointer, `size` its byte size, and `cache`
 * its name.
 *
 * If the resource is missing or cannot be loaded, buffer/size become NULL/0.
 * This prevents the previous resource's data from being returned under the
 * new name.
 */
void InitResource(LPCTSTR name) {
	HRSRC hrsrc = FindResource(NULL, name, RT_RCDATA);
	HGLOBAL hglobal = (hrsrc != NULL) ? LoadResource(NULL, hrsrc) : NULL;
	jbyte* data = (hglobal != NULL) ? (jbyte*)LockResource(hglobal) : NULL;

	buffer = data;
	size = (data != NULL) ? SizeofResource(NULL, hrsrc) : 0;

	// Remember the name only if it fits. Otherwise empty the cache, so lookups
	// of this name reload every time instead of overflowing `cache`.
	size_t length = strlen(name);
	if(length < sizeof(cache) / sizeof(cache[0])) {
		memcpy(cache, name, (length + 1) * sizeof(cache[0]));
	} else {
		cache[0] = '\0';
	}
}

/*
 * Returns the byte size of RCDATA resource `name`. The single-entry cache is
 * refreshed through InitResource() only when `name` differs from the cached one.
 */
DWORD GetResourceSize(LPCTSTR name) {
	if(strcmp(name, cache) == 0) {
		return size;
	}
	InitResource(name);
	return size;
}

/*
 * Returns the data pointer of RCDATA resource `name` as stored by
 * InitResource(). The cache is reloaded only on a name mismatch.
 */
jbyte* GetResourceBuffer(LPCTSTR name) {
	int isCached = (strcmp(name, cache) == 0);
	if(!isCached) {
		InitResource(name);
	}
	return buffer;
}

/*
 * Reports an error in a message box, using the file-global JNIEnv `env`.
 *
 * With no Java exception pending, `text` is shown as is. Otherwise the
 * exception's stack trace is printed (ExceptionDescribe) and the exception is
 * cleared. The message box then shows `text`, "\r\n" and the exception's
 * toString(), or only the toString() result when `text` is NULL. The combined
 * message is truncated to fit a 2048-byte buffer.
 */
void OutputError(LPCTSTR text) {
	if(env->ExceptionCheck() != JNI_TRUE) {
		OutputMessage(text);
		return;
	}

	jthrowable throwable = env->ExceptionOccurred();
	env->ExceptionDescribe();
	env->ExceptionClear();

	LPSTR message = NULL;
	jclass throwableClass = env->FindClass("java/lang/Throwable");
	jmethodID toString = NULL;
	if(throwableClass != NULL) {
		toString = env->GetMethodID(throwableClass, "toString", "()Ljava/lang/String;");
	}
	if(toString != NULL) {
		message = GetShiftJIS((jstring)env->CallObjectMethod(throwable, toString));
	}
	// Do not hand an exception from the lookups or toString() back to the caller.
	if(env->ExceptionCheck() == JNI_TRUE) {
		env->ExceptionClear();
	}

	if(text != NULL) {
		char buffer[2048];
		// Bounded formatting: a long text or exception message must not overflow.
		snprintf(buffer, sizeof(buffer), "%s\r\n%s", text, (message != NULL) ? message : "");
		OutputMessage(buffer);
	} else {
		OutputMessage(message);
	}
	if(message != NULL) {
		GlobalFree(message);
	}
}

/*
 * Shows `text` in an application-modal message box with an exclamation icon.
 * The caption is the executable's file name with its directory removed.
 */
void OutputMessage(LPCTSTR text) {
	char path[MAX_PATH];
	if(GetModuleFileName(NULL, path, MAX_PATH) == 0) {
		path[0] = '\0';
	}
	path[MAX_PATH - 1] = '\0'; // a truncated path may be unterminated on older Windows

	// Walk with CharNextA rather than strrchr, so that the trail byte of a
	// double-byte character (e.g. Shift_JIS 0x5C) is never taken for '\\'.
	// This also avoids adding 1 to a NULL result when no separator exists.
	const char* filename = path;
	for(const char* p = path; *p != '\0'; p = CharNextA(p)) {
		if(*p == '\\') {
			filename = p + 1;
		}
	}

	MessageBox(NULL, text, filename, MB_ICONEXCLAMATION | MB_APPLMODAL | MB_OK | MB_SETFOREGROUND);
}

/*
 * Native implementation of UncaughtHandler.UncaughtException(String, String),
 * registered from main(). It shows `message` (converted to the ANSI code page)
 * in a message box and then terminates the process with exit code 0. `clazz`
 * and `trace` are not used.
 *
 * The conversion uses the JNIEnv passed to this call, not the file-global
 * `env`. A JNIEnv is only valid on its own thread, and this callback
 * presumably runs on whichever Java thread raised the exception.
 */
void JNICALL JNI_UncaughtException(JNIEnv *env, jobject clazz, jstring message, jstring trace) {
	LPSTR text = NULL;
	if(message != NULL) {
		jsize length = env->GetStringLength(message);
		const jchar* unicode = env->GetStringChars(message, NULL);
		if(unicode != NULL) {
			int bytes = 0;
			if(length > 0) {
				bytes = WideCharToMultiByte(CP_ACP, 0, (LPCWSTR)unicode, length, NULL, 0, NULL, NULL);
			}
			// GPTR zero-fills, so the result is always NUL-terminated.
			text = (LPSTR)GlobalAlloc(GPTR, bytes + 1);
			if(text != NULL && bytes > 0) {
				WideCharToMultiByte(CP_ACP, 0, (LPCWSTR)unicode, length, text, bytes, NULL, NULL);
			}
			env->ReleaseStringChars(message, unicode);
		}
	}
	OutputMessage(text);
	if(text != NULL) {
		GlobalFree(text);
	}
	exit(0);
}

/*
 * Builds the mutex name "EXEWRAP:MUTEX:" followed by the executable's file
 * name (directory removed). Returns a malloc'd buffer that the caller releases
 * with free(), or NULL if the allocation fails.
 */
LPSTR GetModuleMutexName() {
	char modulePath[MAX_PATH];
	if(GetModuleFileName(NULL, modulePath, MAX_PATH) == 0) {
		modulePath[0] = '\0';
	}
	modulePath[MAX_PATH - 1] = '\0'; // a truncated path may be unterminated on older Windows

	// Find the last '\\' with a DBCS-aware walk, so that a Shift_JIS trail byte
	// (0x5C) is not mistaken for a separator.
	const char* fileName = modulePath;
	for(const char* p = modulePath; *p != '\0'; p = CharNextA(p)) {
		if(*p == '\\') {
			fileName = p + 1;
		}
	}

	// The prefix (14 chars), a file name shorter than MAX_PATH and the NUL all
	// fit in MAX_PATH + 32 bytes.
	LPSTR mutexName = (LPSTR)malloc(MAX_PATH + 32);
	if(mutexName == NULL) {
		return NULL;
	}
	strcpy(mutexName, "EXEWRAP:MUTEX:");
	strcat(mutexName, fileName);
	return mutexName;
}
