在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)
创始人
2024-03-15 02:08:25
0

文章目录

  • 开发环境
  • 构建第一个 TensorFlow.js 模型
  • 构建鸢尾花数据集分类器
  • References


在 《TensorFlow Lite 是什么?用 TensorFlow Lite 来转换模型(附代码)》中我们已经介绍了可以帮助 TensorFlow 模型在移动设备以及嵌入式设备中运行的 TensorFlow Lite,TensorFlow 生态系统中还包括 TensorFlow.js,它可以帮助我们使用现成的 JavaScript 模型或转换 Python TensorFlow 模型以在浏览器中或 Node.js 下运行。

下面这张图总结了整个 TensorFlow 的生态系统:

在这里插入图片描述

和 TensorFlow Lite 不同的是,TensorFlow.js 还可以用来训练模型。它可以让我们在 JavaScript 中使用类似 keras 的代码语法,非常友好。

开发环境

有能力的朋友可以在任何 web/JavaScript 开发环境下来进行尝试,我们这里直接使用 brackets 官网给出的线上代码编辑器 Phoenix (进入官网后就会自动弹出提示)来进行演示。

进入 Phoenix 后会显示如下画面:

在这里插入图片描述

我们直接将 index.html 文件的内容清空,并先加入以下的大框架:



First HTML Page

构建第一个 TensorFlow.js 模型

以及 标签之间,我们添加下面的 script 标签来指定 TensorFlow.js 库的位置:


之后,我们在第一个 script 标签后添加第二个 script 标签,里面要定义我们的模型,语法和 python 非常相似,但要记得在结尾添加分号:

    >

First HTML Page

点击 File -> Save File,代码会自动运行,如果已经保存,可以直接点击预览页面左上方的刷新按钮,稍等几秒(模型训练),会弹出以下对话框:

在这里插入图片描述

这就是输入数据为 [10] 时模型给出的预测结果!如果我们想查看模型每个 epoch 之后打印的训练损失,直接按下快捷键 Ctrl-Shift-I,并在弹出的面板上方选择 Console,就会有如下结果:

在这里插入图片描述


下面我们训练一个稍微复杂点的模型。


构建鸢尾花数据集分类器

鸢尾花数据集(.csv)共有 150 条数据,每条数据有 4 个特征(sepal length、sepal width、petal length、petal width),对应三种鸢尾花(setosa、versicolor、virginical)。鸢尾花数据集很容易找到,也可以从我这里下载:Iris 鸢尾花数据集(.csv 格式)。

通过常规的机器学习方法,我们可以对数据集做一些可视化,加深认识:

import pandas as pd
import numpy as npimport seaborn as sns
import matplotlib.pyplot as plt
df = pd.read_csv('../input/iris-dataset/iris.csv')
df.head(5)
"""sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa
"""

我们可以通过 sns.pairplot() 画出两两特征之间的关系,且用种类进行划分:

sns.pairplot(df, kind = 'scatter', hue = 'species')
plt.show()

在这里插入图片描述
对角线上为每个种类在某个特征上的分布图,非对角线上则是两个特征选取不同值时对应的鸢尾花种类。

下面我们就开始在 Phoenix 中进行训练吧!

我们点击左上角的新建项目,在本地选择路径,创建新项目,然后将我们下载的 iris 数据集拖入我们项目保存的路径。
在这里插入图片描述

和之前一样,我们先添加以下大框架:



Iris Classifier

我们可以使用 TensorFlow.js 的 tf.data.csv 来加载 CSV 文件,且可以通过它来指定标签对应的列:

    

species 对应的是种类名称的字符串,我们需要先将它转换为数值。我们这里使用独热编码来转换标签:

        const convertedData = trainingData.map(({xs, ys}) => {const labels = [ys.species == 'setosa' ? 1 : 0,ys.species == 'virginica' ? 1: 0,ys.species == 'versicolor' ? 1 : 0]return {xs: Object.values(xs), ys: Object.values(labels)};}).batch(10);

上述代码会将 ‘setosa’ 编码为 [1, 0, 0],将 ‘virginica’ 编码为 [0, 1, 0],而将 ‘versicolor’ 编码为 [0, 0, 1],并返回和之前一样的数据集,除了 species 列的字符串已经被编码为独热向量。

下面我们定义并编译模型,输入层形状为输入特征数(列数减 1),输出层有 3 个神经元:

        const numOfFeatures = (await trainingData.columnNames()).length - 1;const model = tf.sequential();model.add(tf.layers.dense({inputShape: [numOfFeatures],activation: "sigmoid", units: 5}));model.add(tf.layers.dense({activation: "softmax", units: 3}));model.compile({loss: "categoricalCrossentropy",optimizer: tf.train.adam(0.06)});

和之前不同,我们的数据是以数据集的形式组织的,所以训练时我们要使用 fitDataset 方法:

        await model.fitDataset(convertedData,{epochs:100,callbacks:{onEpochEnd: async(epoch, logs) =>{console.log("Epoch: " + epoch + " Loss: " + logs.loss);}}});

如果要测试模型,我们可以使用之前用到的 tensor2d 来创建一个输入数据:

const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
alert(model.predict(testVal));

我们将完整代码给出:


Iris Classifier

运行之后,会弹出如下结果:

在这里插入图片描述
我们可以对结果进一步优化,让其显示预测的具体种类:

const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
const prediction = model.predict(testVal);
const pIndex = tf.argMax(prediction, axis=1).dataSync();const classNames = ["Setosa", "Virginica", "Versicolor"];
alert(classNames[pIndex]);

在这里插入图片描述

References

AI and Machine Learning for Coders by Laurence Moroney.

相关内容

热门资讯

AWSECS:访问外部网络时出... 如果您在AWS ECS中部署了应用程序,并且该应用程序需要访问外部网络,但是无法正常访问,可能是因为...
AWSElasticBeans... 在Dockerfile中手动配置nginx反向代理。例如,在Dockerfile中添加以下代码:FR...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...
北信源内网安全管理卸载 北信源内网安全管理是一款网络安全管理软件,主要用于保护内网安全。在日常使用过程中,卸载该软件是一种常...
AWR报告解读 WORKLOAD REPOSITORY PDB report (PDB snapshots) AW...
AWS管理控制台菜单和权限 要在AWS管理控制台中创建菜单和权限,您可以使用AWS Identity and Access Ma...
​ToDesk 远程工具安装及... 目录 前言 ToDesk 优势 ToDesk 下载安装 ToDesk 功能展示 文件传输 设备链接 ...
群晖外网访问终极解决方法:IP... 写在前面的话 受够了群晖的quickconnet的小水管了,急需一个新的解决方法&#x...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
Azure构建流程(Power... 这可能是由于配置错误导致的问题。请检查构建流程任务中的“发布构建制品”步骤,确保正确配置了“Arti...