博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习tensorflow框架初试
阅读量:6225 次
发布时间:2019-06-21

本文共 5006 字,大约阅读时间需要 16 分钟。

本文来自网易云社区

作者:汪洋

前言

新手学习可以点击参考。开始前,我们先在本地安装好 TensorFlow机器学习框架。 

  1. 首先我们在本地window下安装好python环境,约定安装3.6版本;

  2. 安装Anaconda工具集后,创建名为 tensorflow 的conda 环境:conda create -n tensorflow pip python=3.6;

  3. conda切换环境:activate tensorflow;

  4. 我们安装支持CPU的TensorFlow版本(快速):pip install --ignore-installed --upgrade tensorflow;

  5. 最后验证安装是否成功,进入 python dos命名,输入以下代码校验:

    import tensorflow as tfhello = tf.constant('Hello, TensorFlow')sess = tf.Session()print(sess.run(hello))

    输出Hello, TensorFlow,表示成功了。如果失败的话,就选择低版本重新安装如:pip install --ignore-installed --upgrade tensorflow==1.5.0。

    其它安装方式点击。

监督学习实践

官方针对新手演示了一个入门示例,可查看,本文就围绕这个教程分享。

1.分类

官方示例里讲解了分类鸢尾花问题的解决,我们想到的就是用监督学习训练机器模型。采用这种学习方式后,我们需要确定用鸢尾花的哪些特征来分类,鸢尾花的特征还是蛮多的,官方示例里用的是花萼和花瓣的长度和宽度。

鸢尾花种类非常多,官方也仅是针对三种进行分类:

expected = ['Setosa', 'Versicolor', 'Virginica']

接下来就是获取大量数据,进行预处理,官方示例里直接引用了他人整理的数据源,省略了前期数据处理步骤,前5条数据结构如下:

SepalLength SepalWidth PetalLength PetalWidth Species
0 6.4 2.8 5.6 2.2 2
1 5.0 2.3 3.3 1.0 1
2 4.9 2.5 4.5 1.7 2
3 4.9 3.1 1.5 0.1 0
4 5.7 3.8 1.7 0.1 0

说明:

  1. 最后一列代表着鸢尾花的品种,也就是说它是监督学习中的标签;

  2. 中间四列从左到右表示花萼的长度和宽度、花瓣的长度和宽度;

  3. 表格数据代表了从120个样本的数据集中抽集的5个样本;

    机器学习一般依赖数值,因此当前数据集中标签值都为数字,对应关系: 

0 1 2
Setosa Versicolor Virginica

接下来将编写代码,先复习下概念,模型指特征和标签之间的关系;训练指机器学习阶段,这个阶段模型不断优化。示例里选择的监督试学习方式,模型通过包含标签的样本进行训练。

2. 导入和解析数据集 

首先我们要获取训练集和测试集,其中训练集是训练模型的样本,测试集是评估训练后模型效果的样本。

首先设置我们选择的数据集地址

 """训练集"""TRAN_URL = "http://download.tensorflow.org/data/iris_training.csv""""测试集"""TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

使用tensorflow.keras.utils.get_file函数下载数据集,该方法第一个参数为文件名称,第二个参数为下载地址,点击)。

import tensorflow as tfdef download():    train_path = tf.keras.utils.get_file('iris_training.csv', TRAN_URL)    test_path = tf.keras.utils.get_file('iris_test.csv', TEST_URL)    return train_path, test_path

然后用pandas.read_csv函数解析下载的数据,解析后生成的格式是一个表格,然后再分成特征列表和标签列表,返回训练集和测试集

import pandas as pdCSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',                    'PetalLength', 'PetalWidth', 'Species']def load_data(y_species='Species'):    train_path, test_path = download()    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)    train_x, train_y = train, train.pop(y_species)    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)    test_x, test_y = test, test.pop(y_species)    return (train_x, train_y), (test_x, test_y)

3. 特征列-数值列 

我们已经获取到数据集,在tensorflow中需要将数据转换为模型(Estimator)可以使用的数据结构,这时候调用tf.feature_column模块中的函数来转换。鸢尾花例子中,需将特征数据转换为浮点数,调用tf.feature_column.numeric_column方法。

import iris_data(train_x, train_y), (test_x, test_y) = iris_data.load_data()my_feature_columns = []for key in train_x.keys():    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

其中key是 ['SepalLength' , 'SepalWidth' , 'PetalLength' , 'PetalWidth'] 其中之一。

4. 模型选择 

官方例子中选择全连接神经网络解决鸢尾花问题,用神经网络来发现特征与标签之间的复杂关系。tensorflow中,通过实例化一个Estimator类指定模型类型,这里我们使用官方提供的预创建的Estimator类,tf.estimator.DNNClassifier,此Estimator会构建一个对样本进行分类的神经网络。

classifier = tf.estimator.DNNClassifier(    feature_columns = my_feature_columns,    hidden_units = [10,10],    n_classes = 3)

feature_columns 参数指训练的特征列(这里是数值列);

hidden_units 参数定义神经网络内每个隐藏层中的神经元数量,这里设置了2个隐藏层,每个隐藏层中神经元数量都是10个;
n_classes 参数表示要预测的标签数量,这里我们需要预测3个品种;
其它参数

5. 训练模型 

上一步我们已经创建了一个学习模型,接下来将数据导入到模型中进行训练。tensorflow中,调用Estimator对象的train方法训练。

classifier.train(    input_fn = lambda:iris_data.train_input_fn(train_x, train_y, 100)    steps = 1000)

input_fn 参数表示提供训练数据的函数; steps 参数表示训练迭代次数;

在train_input_fn函数里,我们将数据转换为 train方法所需的格式。 

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

为了保证训练效果,训练样本需随机排序。buffer_size 设置大于样本数(120),可确保数据得到充分的随机化处理。 

dataset = dataset.shuffle(1000)

为了保证训练期间,有无限量的训练样本,需调用 tf.data.Dataset.repeat。

dataset = dataset.repeat()

train方法一次处理一批样本, tf.data.Dataset.batch 方法通过组合多个样本创建一个批次,这里组合多个包含100个样本的批次。

dataset = dataset.batch(100)

6. 模型评估 

接下来我们将训练好的模型预测效果。tensorflow中,每个Estimator对象提供了evaluate方法。

eval_result = classifier.evaluate(    input_fn = lambda:iris_data.eval_input_fn(test_x, test_y, 100))

在eval_input_fn函数里,我们将数据转换为 evaluate方法所需的格式。实现跟训练一样,只是无需随机化处理和无限量重复使用测试集。

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))dataset.batch(100);return dataset

7. 预测 

接下来将该模型对无标签样本进行预测。官方手动提供了三个无标签样本。

predict_x = {    'SepalLength': [5.1, 5.9, 6.9],    'SepalWidth': [3.3, 3.0, 3.1],    'PetalLength': [1.7, 4.2, 5.4],    'PetalWidth': [0.5, 1.5, 2.1],}

tensorflow中,每个Estimator对象提供了predict方法。

predictions = classifier.predict(    input_fn = lambda:iris_data.eval_input_fn(predict_x, labels=None, 100))

改造下eval_input_fn方法,使其能够接受 labels = none 情况

features=dict(features)if labels is None:    inputs = featureselse:    inputs = (features, labels)dataset = tf.data.Dataset.from_tensor_slices(inputs)

接下来打印下预测结果, predictions 中 class_ids表示可能性最大的品种,probabilities 表示每个品种的概率

for pred_dict in predictions:    class_id = pred_dict['class_ids'][0]    probability = pred_dict['probabilities'][class_id]    print(class_id, probability)

结果如下:

0 0.99706334
1 0.997407
2 0.97377485

结尾

通过官方例子,新手可初步了解其使用,当然更深入的使用还得学习理论和多使用API。本文是根据官方例子,作为新手重新梳理了一遍。

网易云,0成本体验20+款云产品

更多网易研发、产品、运营经验分享请访问。

 

相关文章:

【推荐】 

转载地址:http://cbyna.baihongyu.com/

你可能感兴趣的文章
超级有用的15个mysqlbinlog命令
查看>>
数据库之间转移数据
查看>>
PHP连接Mysql常用API(mysql,mysqli,pdo)区别与联系
查看>>
java中的CAS
查看>>
简单的markdown在线解析服务
查看>>
Linux基础(day44)
查看>>
Git 分支创建及使用
查看>>
MariaDB安装, Apache安装
查看>>
多线程三分钟就可以入个门了!
查看>>
从道法术三个层面理解区块链:术
查看>>
elasticsearch入门使用
查看>>
数据结构与算法4
查看>>
tomcat去掉项目名称
查看>>
微服务架构的优势与不足(一)
查看>>
分布式服务治理框架Dubbo
查看>>
小程序好的ui框架选择
查看>>
今天学习了
查看>>
Tomcat安装、配置、优化及负载均衡详解
查看>>
虹软人脸识别SDK(java+linux/window) 初试
查看>>
ppwjs之bootstrap文字排版:到标题元素
查看>>