diff --git a/webclient.c b/webclient.c index 3bb669e..4a500c2 100644 --- a/webclient.c +++ b/webclient.c @@ -48,6 +48,33 @@ char *webclient_strdup(const char *s) return tmp; } +static int webclient_send(struct webclient_session* session, const unsigned char *buffer, size_t len, int flag) +{ + if (!session) + return -RT_ERROR; + +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session) + return mbedtls_client_write(session->tls_session, buffer, len); +#endif + + return send(session->socket, buffer, len, flag); +} + +static int webclient_recv(struct webclient_session* session, unsigned char *buffer, size_t len, int flag) +{ + if (!session) + return -RT_ERROR; + +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session) + return mbedtls_client_read(session->tls_session, buffer, len); +#endif + + return recv(session->socket, buffer, len, flag); +} + + static char *webclient_header_skip_prefix(char *line, const char *prefix) { char *ptr; @@ -78,7 +105,7 @@ static char *webclient_header_skip_prefix(char *line, const char *prefix) * before the data. We need to read exactly to the end of the headers * and no more data. This readline reads a single char at a time. */ -static int webclient_read_line(int socket, char *buffer, int size) +static int webclient_read_line(struct webclient_session* session, char *buffer, int size) { int rc; char *ptr = buffer; @@ -87,7 +114,11 @@ static int webclient_read_line(int socket, char *buffer, int size) /* Keep reading until we fill the buffer. */ while (count < size) { - rc = recv(socket, ptr, 1, 0); + rc = webclient_recv(session, (unsigned char *)ptr, 1, 0); +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session && rc == MBEDTLS_ERR_SSL_WANT_READ) + continue; +#endif if (rc <= 0) return rc; @@ -132,19 +163,27 @@ static int webclient_resolve_address(struct webclient_session *session, struct a int rc = WEBCLIENT_OK; char *ptr; char port_str[6] = "80"; /* default port of 80(http) */ + char port_tls_str[6] = "443"; /* default port of 443(https) */ const char *host_addr = 0; int url_len, host_addr_len = 0; url_len = strlen(url); - /* strip protocol(http) */ - if (strncmp(url, "http://", 7) != 0) + /* strip protocol(http or https) */ + if (strncmp(url, "http://", 7) == 0) + { + host_addr = url + 7; + } + else if(strncmp(url, "https://", 8) == 0) + { + host_addr = url + 8; + } + else { rc = -1; - goto _exit; + goto _exit; } - host_addr = url + 7; /* ipv6 address */ if (host_addr[0] == '[') @@ -190,7 +229,37 @@ static int webclient_resolve_address(struct webclient_session *session, struct a } host_addr_len = ptr - host_addr; *request = (char *)ptr; + +#ifdef PKG_USING_WEBCLIENT_TLS + char *port_tls_ptr; + + if(session->tls_session) + { + port_tls_ptr = strstr(host_addr, ":"); + if (port_tls_ptr) + { + int port_tls_len = ptr - port_tls_ptr - 1; + strncpy(port_tls_str, port_tls_ptr + 1, port_tls_len); + port_str[port_tls_len] = '\0'; + + host_addr_len = port_tls_ptr - host_addr; + } + } + else + { + port_ptr = strstr(host_addr, ":"); + if (port_ptr) + { + int port_len = ptr - port_ptr - 1; + + strncpy(port_str, port_ptr + 1, port_len); + port_str[port_len] = '\0'; + + host_addr_len = port_ptr - host_addr; + } + } +#else port_ptr = strstr(host_addr, ":"); if (port_ptr) { @@ -201,6 +270,7 @@ static int webclient_resolve_address(struct webclient_session *session, struct a host_addr_len = port_ptr - host_addr; } +#endif } if ((host_addr_len < 1) || (host_addr_len > url_len)) @@ -223,7 +293,11 @@ static int webclient_resolve_address(struct webclient_session *session, struct a memcpy(host_addr_new, host_addr, host_addr_len); host_addr_new[host_addr_len] = '\0'; session->host = host_addr_new; - //rt_kprintf("session->host: %s\n", session->host); + +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session) + session->tls_session->host = rt_strdup(host_addr_new); +#endif } { @@ -232,6 +306,30 @@ static int webclient_resolve_address(struct webclient_session *session, struct a int ret; memset(&hint, 0, sizeof(hint)); + +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session) + { + session->tls_session->port = rt_strdup(port_tls_str); + ret = getaddrinfo(session->tls_session->host, port_tls_str, &hint, res); + if (ret != 0) + { + rt_kprintf("getaddrinfo err: %d '%s'\n", ret, session->host); + rc = -1; + goto _exit; + } + } + else + { + ret = getaddrinfo(session->host, port_str, &hint, res); + if (ret != 0) + { + rt_kprintf("getaddrinfo err: %d '%s'\n", ret, session->host); + rc = -1; + goto _exit; + } + } +#else ret = getaddrinfo(session->host, port_str, &hint, res); if (ret != 0) { @@ -239,8 +337,8 @@ static int webclient_resolve_address(struct webclient_session *session, struct a rc = -1; goto _exit; } +#endif } - _exit: if (rc != WEBCLIENT_OK) { @@ -387,7 +485,7 @@ int webclient_handle_response(struct webclient_session *session) int i; /* read a line from the header information. */ - rc = webclient_read_line(session->socket, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ); + rc = webclient_read_line(session, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ); if (rc < 0) break; @@ -469,7 +567,7 @@ int webclient_handle_response(struct webclient_session *session) && strcmp(session->transfer_encoding, "chunked") == 0) { /* chunk mode, we should get the first chunk size */ - webclient_read_line(session->socket, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ); + webclient_read_line(session, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ); session->chunk_sz = strtol(mimeBuffer, RT_NULL, 16); session->chunk_offset = 0; } @@ -521,28 +619,62 @@ int webclient_connect(struct webclient_session *session, const char *URI) else session->request = RT_NULL; - socket_handle = socket(res->ai_family, SOCK_STREAM, IPPROTO_TCP); // - if (socket_handle < 0) +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session) { - rc = -WEBCLIENT_NOSOCKET; + int tls_ret = 0; + + if((tls_ret = mbedtls_client_context(session->tls_session)) < 0) + { + rt_kprintf("webclient mbedtls_client_context err return : -0x%x\n", -tls_ret); + return -RT_ERROR; + } + + if((tls_ret = mbedtls_client_connect(session->tls_session)) < 0) + { + rt_kprintf("webclient mbedtls_client_connect err return : -0x%x\n", -tls_ret); + rc = -WEBCLIENT_CONNECT_FAILED; + goto _exit; + } + + socket_handle = session->tls_session->server_fd.fd; + + /* set recv timeout option */ + setsockopt(socket_handle, SOL_SOCKET, SO_RCVTIMEO, (void*) &timeout, + sizeof(timeout)); + setsockopt(socket_handle, SOL_SOCKET, SO_SNDTIMEO, (void*) &timeout, + sizeof(timeout)); + + session->socket = socket_handle; + rc = WEBCLIENT_OK; goto _exit; } +#endif - /* set recv timeout option */ - setsockopt(socket_handle, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout, - sizeof(timeout)); - setsockopt(socket_handle, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout, - sizeof(timeout)); + { + socket_handle = socket(res->ai_family, SOCK_STREAM, IPPROTO_TCP); // + if (socket_handle < 0) + { + rc = -WEBCLIENT_NOSOCKET; + goto _exit; + } - if (connect(socket_handle, res->ai_addr, res->ai_addrlen) != 0) - { - /* connect failed, close socket handle */ - closesocket(socket_handle); - rc = -WEBCLIENT_CONNECT_FAILED; - goto _exit; - } + /* set recv timeout option */ + setsockopt(socket_handle, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout, + sizeof(timeout)); + setsockopt(socket_handle, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout, + sizeof(timeout)); + + if (connect(socket_handle, res->ai_addr, res->ai_addrlen) != 0) + { + /* connect failed, close socket handle */ + closesocket(socket_handle); + rc = -WEBCLIENT_CONNECT_FAILED; + goto _exit; + } - session->socket = socket_handle; + session->socket = socket_handle; + } _exit: if (res) @@ -553,6 +685,42 @@ _exit: return rc; } +int webclient_open_tls(struct webclient_session * session, const char *URI) +{ +#ifdef PKG_USING_WEBCLIENT_TLS + int tls_ret = 0; + const char *pers = "wenclient"; + + if(!session) + return -RT_ERROR; + + session->tls_session = (MbedTLSSession *)web_malloc(sizeof(MbedTLSSession)); + if (session->tls_session == RT_NULL) + return -RT_ERROR; + memset(session->tls_session, 0x0, sizeof(MbedTLSSession)); + + session->tls_session->buffer_len = WEBCLIENT_TLS_READ_BUFFER; + session->tls_session->buffer = web_malloc(session->tls_session->buffer_len); + if(session->tls_session->buffer == RT_NULL) + { + rt_kprintf("no memory for webclient tls_session buffer malloc\n"); + return -RT_ERROR; + } + + if((tls_ret = mbedtls_client_init(session->tls_session, (void *)pers, strlen(pers))) < 0) + { + rt_kprintf("webclient mbedtls_client_init err return : -0x%x\n", -tls_ret); + return -RT_ERROR; + } + + return RT_EOK; +#else + rt_kprintf("don't support TLS protocol, check your menuconfig!\n"); + return -RT_ERROR; + +#endif +} + struct webclient_session *webclient_open(const char *URI) { struct webclient_session *session; @@ -562,6 +730,16 @@ struct webclient_session *webclient_open(const char *URI) if (session == RT_NULL) return RT_NULL; memset(session, 0x0, sizeof(struct webclient_session)); + session->socket = -1; + + if(strncmp(URI, "https://", 8) == 0) + { + if(webclient_open_tls(session, URI) < 0) + { + webclient_close(session); + return RT_NULL; + } + } if (webclient_connect(session, URI) < 0) { @@ -611,6 +789,15 @@ struct webclient_session *webclient_open_position(const char *URI, int position) return RT_NULL; memset(session, 0x0, sizeof(struct webclient_session)); + if(strncmp(URI, "https://", 8) == 0) + { + if(webclient_open_tls(session, URI) < 0) + { + webclient_close(session); + return RT_NULL; + } + } + if (webclient_connect(session, URI) < 0) { /* connect to webclient server failed. */ @@ -671,6 +858,15 @@ struct webclient_session *webclient_open_header(const char *URI, int method, return RT_NULL; memset(session, 0, sizeof(struct webclient_session)); + if(strncmp(URI, "https://", 8) == 0) + { + if(webclient_open_tls(session, URI) < 0) + { + webclient_close(session); + return RT_NULL; + } + } + if (webclient_connect(session, URI) < 0) { /* connect to webclient server failed. */ @@ -715,12 +911,12 @@ static int webclient_next_chunk(struct webclient_session *session) char line[64]; int length; - length = webclient_read_line(session->socket, line, sizeof(line)); + length = webclient_read_line(session, line, sizeof(line)); if (length) { if (strcmp(line, "\r\n") == 0) { - length = webclient_read_line(session->socket, line, sizeof(line)); + length = webclient_read_line(session, line, sizeof(line)); if (length <= 0) { closesocket(session->socket); @@ -766,7 +962,7 @@ int webclient_read(struct webclient_session *session, unsigned char *buffer, if (length > (session->chunk_sz - session->chunk_offset)) length = session->chunk_sz - session->chunk_offset; - bytesRead = recv(session->socket, buffer, length, 0); + bytesRead = webclient_recv(session, buffer, length, 0); if (bytesRead <= 0) { if (errno == EWOULDBLOCK || errno == EAGAIN) @@ -811,9 +1007,13 @@ int webclient_read(struct webclient_session *session, unsigned char *buffer, left = length; do { - bytesRead = recv(session->socket, buffer + totalRead, left, 0); + bytesRead = webclient_recv(session, buffer + totalRead, left, 0); if (bytesRead <= 0) { +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session && bytesRead == MBEDTLS_ERR_SSL_WANT_READ) + continue; +#endif rt_kprintf("errno=%d\n", bytesRead); if (totalRead) @@ -868,9 +1068,13 @@ int webclient_write(struct webclient_session *session, */ do { - bytesWrite = send(session->socket, buffer + totalWrite, left, 0); + bytesWrite = webclient_send(session, buffer + totalWrite, left, 0); if (bytesWrite <= 0) { +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session && bytesWrite == MBEDTLS_ERR_SSL_WANT_WRITE) + continue; +#endif if (errno == EWOULDBLOCK || errno == EAGAIN) { /* send timeout */ @@ -905,15 +1109,25 @@ int webclient_write(struct webclient_session *session, int webclient_close(struct webclient_session *session) { RT_ASSERT(session != RT_NULL); - + +#ifdef PKG_USING_WEBCLIENT_TLS + if(session->tls_session) + mbedtls_client_close(session->tls_session); +#endif if (session->socket >= 0) - closesocket(session->socket); - web_free(session->transfer_encoding); - web_free(session->content_type); - web_free(session->last_modified); - web_free(session->host); - web_free(session->request); - web_free(session); + closesocket(session->socket); + if(session->transfer_encoding) + web_free(session->transfer_encoding); + if(session->content_type) + web_free(session->content_type); + if(session->last_modified) + web_free(session->last_modified); + if(session->host) + web_free(session->host); + if(session->request) + web_free(session->request); + if(session) + web_free(session); return 0; } diff --git a/webclient.h b/webclient.h index f653f4b..0190ab9 100644 --- a/webclient.h +++ b/webclient.h @@ -18,8 +18,13 @@ #include +#ifdef PKG_USING_WEBCLIENT_TLS +#include +#endif + #define WEBCLIENT_HEADER_BUFSZ 4096 #define WEBCLIENT_RESPONSE_BUFSZ 4096 +#define WEBCLIENT_TLS_READ_BUFFER 4096 //typedef unsigned int size_t; @@ -77,6 +82,11 @@ struct webclient_session /* remainder of content reading */ size_t content_length_remainder; + +#ifdef PKG_USING_WEBCLIENT_TLS + /* mbedtls session struct*/ + MbedTLSSession *tls_session; +#endif }; struct webclient_session *webclient_open(const char *URI); diff --git a/webclient_file.c b/webclient_file.c index 970ff4c..c7b6565 100644 --- a/webclient_file.c +++ b/webclient_file.c @@ -131,6 +131,14 @@ int webclient_post_file(const char* URI, const char* filename, } memset(session, 0x0, sizeof(struct webclient_session)); + if(strncmp(URI, "https://", 8) == 0) + { + if(webclient_open_tls(session, URI) < 0) + { + goto __exit; + } + } + rc = webclient_connect(session, URI); if (rc < 0) goto __exit; diff --git a/webclient_internal.h b/webclient_internal.h index 71b193c..28a8856 100644 --- a/webclient_internal.h +++ b/webclient_internal.h @@ -3,7 +3,7 @@ #include -#ifdef RT_USING_ESP_PSRAM +#ifdef RT_USING_PSRAM #include #define web_malloc sdram_malloc