Hi everyone, I'm 程序员晚枫, a programmer currently working hands-on with all kinds of AI projects.

Today I want to share a technique that cut my memory usage by 90% when processing large datasets: the generator.
## A Real Memory Blow-Up

Last year a student asked me: "晚枫老师, my program crashes with a MemoryError when it processes 1,000,000 records. What should I do?"

I took one look at his code:
```python
def process_users():
    # Loads all 1,000,000 users into memory at once
    users = [fetch_user(i) for i in range(1000000)]
    results = []
    for user in users:
        results.append(analyze(user))
    return results
```
The problem: 1,000,000 user objects at roughly 1 KB each is about 1 GB of memory!
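As a rough sanity check on that 1 KB estimate, you can measure a single record with `sys.getsizeof`. This is only a sketch: the `user` dict below is hypothetical, since the student's actual `fetch_user()` is not shown.

```python
import sys

# Hypothetical ~1 KB user record; the real fetch_user() result is not shown above.
user = {"id": 1, "name": "alice", "email": "alice@example.com", "bio": "x" * 900}

# sys.getsizeof counts only the dict container itself,
# so add the keys and values for a rough per-record total.
size = sys.getsizeof(user) + sum(
    sys.getsizeof(k) + sys.getsizeof(v) for k, v in user.items()
)
print(f"one record ~ {size} bytes; 1,000,000 records ~ {size * 1_000_000 / 1024**3:.2f} GB")
```

Even this back-of-the-envelope measurement lands in the right ballpark: a million small records easily reaches a gigabyte.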
After optimizing with a generator:
```python
def process_users():
    def user_generator():
        for i in range(1000000):
            yield fetch_user(i)  # fetch one user at a time, on demand

    results = []
    for user in user_generator():
        results.append(analyze(user))
    return results
```
You may have run into this yourself: you need to process tens or hundreds of thousands of records, and the program freezes or dies with a MemoryError. Generators are the way out.

By the end of this article, you'll understand why generators are called a "memory-saving superpower".
## The Problem: Lists Eat Memory

### A memory footprint test

```python
import sys

numbers_list = [i * i for i in range(1000000)]
numbers_gen = (i * i for i in range(1000000))

print(f"List memory: {sys.getsizeof(numbers_list) / 1024 / 1024:.2f} MB")
print(f"Generator memory: {sys.getsizeof(numbers_gen) / 1024:.2f} KB")
```
### The problem with a conventional list

```python
def get_numbers_list(n):
    """Return a list: generates all the data up front."""
    result = []
    for i in range(n):
        result.append(i * i)
    return result

numbers = get_numbers_list(1000000)  # all 1,000,000 squares now live in memory
for num in numbers[:10]:             # ...even though we only use 10 of them
    print(num)
```
## The Solution: Generators

### What is a generator?

A generator is a special kind of iterator: it produces values on demand instead of generating everything up front.
It's like a vending machine: you ask for one item and it hands you one item, rather than dumping the whole inventory in front of you.
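You can confirm that a generator really is "a special kind of iterator" with a minimal check against the standard iterator protocol:

```python
from collections.abc import Iterator

gen = (x * x for x in range(3))

# A generator object implements the full iterator protocol...
print(isinstance(gen, Iterator))  # True
# ...and calling iter() on an iterator returns the same object.
print(iter(gen) is gen)           # True
# Nothing is computed until you ask for a value:
print(next(gen))                  # 0
```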
### Two ways to create a generator

#### Way 1: a generator function (`yield`)

```python
def get_numbers_generator(n):
    """Generator function: produces values on demand."""
    for i in range(n):
        yield i * i

gen = get_numbers_generator(1000000)
for num in gen:
    print(num)
    if num > 100:
        break  # stop early; the remaining values are never computed
```
#### Way 2: a generator expression

```python
import sys

squares_list = [i * i for i in range(1000000)]
print(f"List: {sys.getsizeof(squares_list) / 1024 / 1024:.2f} MB")

squares_gen = (i * i for i in range(1000000))
print(f"Generator: {sys.getsizeof(squares_gen)} bytes")

for square in squares_gen:
    print(square)
```
### Comparing the two

| Feature | Generator function | Generator expression |
| --- | --- | --- |
| Syntax | `def func(): yield x` | `(x for x in iterable)` |
| Complexity | Can hold complex logic | Best for simple expressions |
| Readability | Easier to read | More concise |
| Flexibility | High | Low |
```python
# Complex logic → use a generator function
def process_data(filename):
    with open(filename, 'r') as f:
        for line in f:
            cleaned = line.strip().lower()
            if cleaned and not cleaned.startswith('#'):
                yield cleaned

# Simple transform → a generator expression is enough
cleaned = (line.strip().lower() for line in open('data.txt'))
```
## How `yield` Works

### Saving and restoring state

```python
def simple_generator():
    print("start")
    yield 1
    print("resume")
    yield 2
    print("end")
    yield 3

gen = simple_generator()
print(next(gen))  # prints "start", then 1
print(next(gen))  # prints "resume", then 2
print(next(gen))  # prints "end", then 3
```
### Execution flow

```
Call the generator function
        ↓
A generator object is created (the function body does NOT run yet)
        ↓
next() is called
        ↓
Execution runs to the first yield
        ↓
The yielded value is returned; execution pauses
        ↓
next() is called again
        ↓
Execution resumes from where it paused
        ↓
Runs to the next yield
        ↓
...repeat...
        ↓
Function body ends → StopIteration is raised
```
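The final step in the diagram, the StopIteration at the end, is easy to observe by calling `next()` one time too many. A minimal sketch:

```python
def two_values():
    yield 1
    yield 2

gen = two_values()
print(next(gen))  # 1
print(next(gen))  # 2
try:
    next(gen)  # the function body has returned, so...
except StopIteration:
    print("exhausted")  # ...StopIteration is raised; for loops absorb it for you
```

This is exactly why a plain `for` loop over a generator stops cleanly: the loop catches StopIteration internally.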
## Advanced `yield`

### `yield from` (Python 3.3+)

```python
# Before Python 3.3: an explicit nested loop
def chain_manual(*iterables):
    for iterable in iterables:
        for item in iterable:
            yield item

# Python 3.3+: yield from delegates to the sub-iterable
def chain(*iterables):
    for iterable in iterables:
        yield from iterable

result = list(chain([1, 2], [3, 4], [5]))
print(result)  # [1, 2, 3, 4, 5]

def flatten(nested):
    """Flatten an arbitrarily nested structure."""
    for item in nested:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)  # recursion composes naturally with yield from
        else:
            yield item

nested = [1, [2, 3], [4, [5, 6]]]
print(list(flatten(nested)))  # [1, 2, 3, 4, 5, 6]
```
### `yield` as an expression (the basis of coroutines)

```python
def accumulator():
    """A running-total accumulator (coroutine example)."""
    total = 0
    while True:
        value = yield total  # pause, hand back total, wait for send()
        if value is None:
            break
        total += value

acc = accumulator()
next(acc)            # prime the coroutine: run to the first yield
print(acc.send(10))  # 10
print(acc.send(5))   # 15
print(acc.send(3))   # 18
acc.close()
```
## Performance: List vs Generator

### Memory footprint

```python
import tracemalloc

def test_memory():
    """Measure memory usage of a list vs a generator."""
    tracemalloc.start()

    snapshot1 = tracemalloc.take_snapshot()
    numbers_list = [i * i for i in range(1000000)]
    snapshot2 = tracemalloc.take_snapshot()
    list_memory = sum(stat.size for stat in snapshot2.compare_to(snapshot1, 'filename'))

    snapshot3 = tracemalloc.take_snapshot()
    numbers_gen = (i * i for i in range(1000000))
    snapshot4 = tracemalloc.take_snapshot()
    gen_memory = sum(stat.size for stat in snapshot4.compare_to(snapshot3, 'filename'))

    print(f"List memory: {list_memory / 1024 / 1024:.2f} MB")
    print(f"Generator memory: {gen_memory / 1024:.2f} KB")
    print(f"Saved: {(list_memory - gen_memory) / list_memory * 100:.1f}%")

test_memory()
```
### Execution time

```python
import time

def test_speed():
    """Compare the time to obtain the first 10 squares."""
    # List: build all 1,000,000 values, then use 10 of them
    start = time.time()
    squares_list = [i * i for i in range(1000000)]
    result_list = sum(squares_list[:10])
    list_time = time.time() - start

    # Generator: compute only the 10 values actually consumed
    start = time.time()
    squares_gen = (i * i for i in range(1000000))
    result_gen = sum(next(squares_gen) for _ in range(10))
    gen_time = time.time() - start

    print(f"List:      {list_time:.4f}s, result: {result_list}")
    print(f"Generator: {gen_time:.4f}s, result: {result_gen}")
    print(f"Generator is {list_time / gen_time:.1f}x faster")

test_speed()
```
### Chained operations

```python
import time

def process_data_list(data):
    """List style: three passes, two intermediate lists in memory."""
    filtered = [x for x in data if x > 500000]
    transformed = [x * 2 for x in filtered]
    return sum(transformed)

def process_data_gen(data):
    """Generator style: one pass, no intermediate lists."""
    return sum(x * 2 for x in data if x > 500000)

data = range(1000000)

start = time.time()
result1 = process_data_list(data)
list_time = time.time() - start

start = time.time()
result2 = process_data_gen(data)
gen_time = time.time() - start

print(f"List style:      {list_time:.4f}s")
print(f"Generator style: {gen_time:.4f}s")
print(f"Same result: {result1 == result2}")
```
## Real-World Examples

### Example 1: Reading a large file

```python
def read_file_bad(filename):
    """Bad: readlines() loads the entire file into memory."""
    with open(filename, 'r') as f:
        return f.readlines()

def read_file_good(filename):
    """Good: yield one line at a time."""
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            yield line.strip()

for line in read_file_good('huge_file.txt'):
    process(line)
```
### Example 2: Infinite sequences

```python
def fibonacci():
    """Generate the infinite Fibonacci sequence."""
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

def count(start=0, step=1):
    """Infinite counter."""
    n = start
    while True:
        yield n
        n += step

def cycle(iterable):
    """Cycle through an iterable forever."""
    while True:
        for item in iterable:
            yield item

fib = fibonacci()
for _ in range(10):
    print(next(fib), end=' ')
print()

counter = count(10, 2)
for _ in range(5):
    print(next(counter), end=' ')
print()

colors = cycle(['red', 'green', 'blue'])
for _ in range(7):
    print(next(colors), end=' ')
```
### Example 3: A data pipeline

```python
def read_log_file(filename):
    """Read a log file line by line."""
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            yield line.strip()

def filter_errors(lines):
    """Keep only ERROR lines."""
    for line in lines:
        if 'ERROR' in line:
            yield line

def parse_timestamp(lines):
    """Pull the timestamp out of each line."""
    for line in lines:
        timestamp = line[1:20]
        yield {'time': timestamp, 'message': line}

def batch(lines, size):
    """Group items into batches of `size`."""
    group = []
    for line in lines:
        group.append(line)
        if len(group) == size:
            yield group
            group = []
    if group:
        yield group

# Compose the pipeline: nothing runs until the final loop pulls data through
logs = read_log_file('app.log')
errors = filter_errors(logs)
parsed = parse_timestamp(errors)
batches = batch(parsed, 100)

for group in batches:
    save_to_database(group)
```
### Example 4: Paginated queries

```python
def paginated_query(db, table, page_size=100):
    """Generator that pages through query results."""
    offset = 0
    while True:
        query = f"SELECT * FROM {table} LIMIT {page_size} OFFSET {offset}"
        results = db.execute(query)
        if not results:
            break
        for row in results:
            yield row
        offset += page_size

for row in paginated_query(db, 'users'):
    process_user(row)
```
### Example 5: Searching files

```python
from pathlib import Path

def find_files(directory, pattern='*'):
    """Recursively find files matching a glob pattern."""
    directory = Path(directory)
    for path in directory.rglob(pattern):
        if path.is_file():
            yield path

def find_by_content(directory, keyword):
    """Search files for a keyword, line by line."""
    for filepath in find_files(directory, '*.py'):
        with open(filepath, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if keyword in line:
                    yield {
                        'file': str(filepath),
                        'line_num': line_num,
                        'line': line.strip(),
                    }

for result in find_by_content('./my_project', 'TODO'):
    print(f"{result['file']}:{result['line_num']}: {result['line']}")
```
### Example 6: Stream processing

```python
import random
import time

def data_stream():
    """Simulate a real-time data stream."""
    while True:
        yield {
            'timestamp': time.time(),
            'value': random.random() * 100,
        }
        time.sleep(0.1)

def sliding_window(stream, size=10):
    """Yield a sliding window over the stream."""
    window = []
    for item in stream:
        window.append(item)
        if len(window) > size:
            window.pop(0)
        yield window.copy()

def moving_average(windows):
    """Moving average over each window."""
    for window in windows:
        if window:
            yield sum(item['value'] for item in window) / len(window)

stream = data_stream()
windows = sliding_window(stream, size=10)
averages = moving_average(windows)

for avg in averages:
    print(f"Moving average: {avg:.2f}")
```
## Generators vs Iterators

### The iterator protocol

```python
class CounterIterator:
    """A hand-written iterator."""
    def __init__(self, start, end):
        self.current = start
        self.end = end

    def __iter__(self):
        return self

    def __next__(self):
        if self.current >= self.end:
            raise StopIteration
        result = self.current
        self.current += 1
        return result

counter = CounterIterator(0, 5)
for num in counter:
    print(num)
```
### The same thing as a generator

```python
def counter_generator(start, end):
    """Generator: far more concise."""
    current = start
    while current < end:
        yield current
        current += 1

counter = counter_generator(0, 5)
for num in counter:
    print(num)
```
### Summary

| Feature | Generator | Hand-written iterator |
| --- | --- | --- |
| How to create | `yield` or a generator expression | `__iter__` and `__next__` methods |
| Code complexity | Simple (a few lines) | More complex (needs a class) |
| Memory footprint | Very low | Low |
| Reusability | Single-use (exhausted after one pass) | Single-use |
| State management | Automatic | Manual |
| Exception handling | `StopIteration` raised automatically | Raised manually |
Bottom line: reach for a generator whenever you can; the code is both shorter and more efficient.
## Handy Lazy Built-ins

These built-ins aren't generator functions in the strict sense, but they all produce values lazily, generator-style.

### enumerate

```python
fruits = ['apple', 'banana', 'cherry']

# Clunky: index into the list manually
for i in range(len(fruits)):
    print(i, fruits[i])

# Better: enumerate yields (index, item) pairs
for i, fruit in enumerate(fruits):
    print(i, fruit)

# Start counting from 1
for i, fruit in enumerate(fruits, start=1):
    print(f"{i}. {fruit}")
```
### zip

```python
names = ['Alice', 'Bob', 'Charlie']
ages = [25, 30, 35]
cities = ['Beijing', 'Shanghai', 'Guangzhou']

for name, age, city in zip(names, ages, cities):
    print(f"{name}, age {age}, {city}")

# zip pairs up nicely with dict()
user_dict = dict(zip(names, ages))
print(user_dict)  # {'Alice': 25, 'Bob': 30, 'Charlie': 35}
```
### map

```python
numbers = [1, 2, 3, 4, 5]

squared_list = [x ** 2 for x in numbers]      # eager: builds the list now
squared_map = map(lambda x: x ** 2, numbers)  # lazy: computes on demand
print(list(squared_map))  # [1, 4, 9, 16, 25]

# map over multiple iterables in lockstep
list1 = [1, 2, 3]
list2 = [10, 20, 30]
result = map(lambda x, y: x + y, list1, list2)
print(list(result))  # [11, 22, 33]
```
### filter

```python
numbers = range(20)

evens_list = [x for x in numbers if x % 2 == 0]       # eager
evens_filter = filter(lambda x: x % 2 == 0, numbers)  # lazy
print(list(evens_filter))

# filter(None, ...) keeps only truthy values
items = [0, 1, None, False, '', 'hello', []]
truthy = filter(None, items)
print(list(truthy))  # [1, 'hello']
```
### range

```python
import sys

r = range(1000000)
print(sys.getsizeof(r))  # tiny and constant — range is lazy, no matter how long

for i in range(10):
    print(i)

for i in range(0, 10, 2):   # start, stop, step
    print(i)

for i in range(10, 0, -1):  # counting down
    print(i)
```
### itertools

```python
import itertools

# count: infinite counting
for i in itertools.count(10, 2):
    if i > 20:
        break
    print(i)

# cycle: loop over an iterable forever
colors = itertools.cycle(['red', 'green', 'blue'])
for _ in range(7):
    print(next(colors))

# repeat: the same value n times
for item in itertools.repeat('hello', 3):
    print(item)

# chain: concatenate iterables
combined = itertools.chain([1, 2], [3, 4], [5, 6])
print(list(combined))  # [1, 2, 3, 4, 5, 6]

# islice: slice a lazy iterable
result = itertools.islice(range(100), 10, 20)
print(list(result))

# takewhile / dropwhile
result = itertools.takewhile(lambda x: x < 5, range(10))
print(list(result))  # [0, 1, 2, 3, 4]

result = itertools.dropwhile(lambda x: x < 5, range(10))
print(list(result))  # [5, 6, 7, 8, 9]

# permutations / combinations
perms = itertools.permutations([1, 2, 3], 2)
print(list(perms))

combs = itertools.combinations([1, 2, 3, 4], 2)
print(list(combs))
```
## Generator Utility Patterns

### Consuming a generator

```python
import itertools

# Consume everything into a list
squares = (i * i for i in range(10))
print(list(squares))

# Take only the first 10 values
squares = (i * i for i in range(100))
first_10 = list(itertools.islice(squares, 10))
print(first_10)

# Take the value at index 4
squares = (i * i for i in range(10))
fifth = next(itertools.islice(squares, 4, None))
print(fifth)  # 16

# Aggregate directly — no intermediate list needed
squares = (i * i for i in range(10))
print(sum(squares))  # 285

squares = (i * i for i in range(10))
print(max(squares))  # 81
```
### "Resetting" a generator

```python
gen = (i * i for i in range(5))
print(list(gen))  # [0, 1, 4, 9, 16]
print(list(gen))  # [] — a generator cannot be rewound

# Option 1: a factory function that builds a fresh generator each time
def squares(n):
    return (i * i for i in range(n))

gen1 = squares(5)
gen2 = squares(5)

# Option 2: if you need to iterate repeatedly, materialize a list once
data = list(i * i for i in range(5))
```
### tee: duplicating a generator

```python
import itertools

gen = (i * i for i in range(5))
gen1, gen2, gen3 = itertools.tee(gen, 3)

# Each copy can be consumed independently (tee buffers values internally)
print(list(gen1))  # [0, 1, 4, 9, 16]
print(list(gen2))  # [0, 1, 4, 9, 16]
print(list(gen3))  # [0, 1, 4, 9, 16]
```
## Pitfalls

### Pitfall 1: A generator is single-use

```python
gen = (i * i for i in range(5))

for num in gen:
    print(num)   # 0 1 4 9 16

for num in gen:
    print(num)   # prints nothing — the generator is already exhausted

data = list(gen)  # [] — still empty
```
### Pitfall 2: No indexing

```python
gen = (i * i for i in range(10))

# gen[5]  # TypeError: 'generator' object is not subscriptable

# Workaround: skip ahead with islice
import itertools
fifth = next(itertools.islice(gen, 5, None))
print(fifth)  # 25
```
### Pitfall 3: No len()

```python
gen = (i * i for i in range(10))

# len(gen)  # TypeError: object of type 'generator' has no len()

length = len(list(gen))  # materialize first, then count (costs memory)

gen = (i * i for i in range(10))
count = sum(1 for _ in gen)  # count without building a list
print(count)  # 10
```
### Pitfall 4: No slicing

```python
gen = (i * i for i in range(10))

# gen[2:5]  # TypeError: 'generator' object is not subscriptable

import itertools
result = list(itertools.islice(gen, 2, 5))
print(result)  # [4, 9, 16]
```
### Pitfall 5: Nested generator gotchas

```python
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Nested generator expression: the for clauses read left to right, outer first
flat = (item for row in matrix for item in row)
print(list(flat))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]

# Equivalent explicit loops, for comparison
flat = []
for row in matrix:
    for item in row:
        flat.append(item)

# With a condition on the innermost item
even = (item for row in matrix for item in row if item % 2 == 0)
print(list(even))  # [2, 4, 6, 8]
```
## Recommended: The AI Python零基础实战营 (beginner bootcamp)

Want to learn Python's advanced features systematically?

What the course covers:

- ✅ Python fundamentals
- ✅ Generators and iterators in depth
- ✅ Memory optimization techniques
- ✅ Hands-on big-data processing

🎁 Limited-time bonus: a print copy of 《Python编程从入门到实践》 (*Python Crash Course*)

👉 Click for details
PS: Generators are one of Python's advanced features. Master them and you can handle large datasets gracefully. Remember: when memory runs short, generators pick up the slack!
📚 Recommended textbook: 《Python 编程从入门到实践(第 3 版)》 (*Python Crash Course*, 3rd Edition)
📚 Recommended: Python 零基础实战营 — for a systematic path into Python, check out this free beginner course 👇

| Feature | Details |
| --- | --- |
| 🎯 Built for absolute beginners | Low barrier, quick to start |
| 📹 Companion video walkthroughs | Works best alongside the articles |
| 💬 Dedicated Q&A group | Someone to help when you get stuck |
| 🎁 Free book | Top students receive 《Python编程从入门到实践》 |

👉 Click to claim the free Python 零基础实战营
💬 Contact me — main services: AI programming training, corporate training, technical consulting

🎓 Want to learn AI programming systematically? 程序员晚枫's hands-on AI programming course takes you from zero!